In [127]:
import pandas as pd
from random import randint, choice
from scipy.stats import binom
from numpy.random import choice
from tqdm.notebook import tqdm

In [128]:
dt = pd.read_csv('forecast1.csv', sep=';')
fishs = ['щука', 'судак', 'окунь', 'берш', 'речная форель', 'озерная форель', 'елец', 'чехонь', 'сом', 'голавль', 'язь',
         'карп', 'жерех', 'лещ', 'карась', 'линь', 'пескарь', 'ротан', 'плотва', 'красноперка', 'налим', 'густера',
         'амур', 'ерш', 'сазан', 'подуст', 'толстолобик', 'вобла', 'хариус']
dt = dt.rename(columns={fish: fish.capitalize() for fish in fishs})
dt

Unnamed: 0,day_temp,day_pressure,day_obl,day_phen,day_dir,day_wind,areal,city,year,month,...,Красноперка,Налим,Густера,Амур,Ерш,Сазан,Подуст,Толстолобик,Вобла,Хариус
0,-4,749,dull,snow,Ю,3,Алтайский край,Барнаул,2020,1,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
1,-1,750,dull,-,Ю,5,Алтайский край,Барнаул,2020,1,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
2,-3,749,dull,-,Ю,2,Алтайский край,Барнаул,2020,1,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
3,-6,753,dull,-,Ю,1,Алтайский край,Барнаул,2020,1,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
4,-2,752,suncl,-,Ю,3,Алтайский край,Барнаул,2020,1,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30516,-6,752,dull,-,З,1,Московская область,Щелково,2020,12,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
30517,-6,759,sunc,-,ЮВ,1,Московская область,Щелково,2020,12,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
30518,-5,754,dull,-,ЮВ,3,Московская область,Щелково,2020,12,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0
30519,-3,753,dull,-,-,0,Московская область,Щелково,2020,12,...,-1,-1.0,-1,-1,-1,-1.0,-1,-1,-1,-1.0


In [129]:
class DayGenerator:
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.times = [12, 15, 18, 21, 0, 3, 6, 9]
        self.dirs = {
            'С': {'values': ['СЗ', 'СВ'], 'p': [0.2, 0.3]},
            'СЗ': {'values': ['С', 'З'], 'p': [0.25, 0.25]},
            'З': {'values': ['СЗ', 'ЮЗ'], 'p': [0.2, 0.3]},
            'ЮЗ': {'values': ['З', 'Ю'], 'p': [0.25, 0.25]},
            'Ю': {'values': ['ЮЗ', 'ЮВ'], 'p': [0.3, 0.2]},
            'ЮВ': {'values': ['Ю', 'В'], 'p': [0.2, 0.3]},
            'В': {'values': ['СВ', 'ЮВ'], 'p': [0.3, 0.2]},
            'СВ': {'values': ['С', 'В'], 'p': [0.3, 0.2]},
        }
        
        self.phens = {
            'snow': {'values': ['небольшой снег', 'снег', 'снег с дождём', 'сильный снег', 'мокрый снег'], 'p': [0.35, 0.25, 0.15, 0.1, 0.15]},
            'rain': {'values': ['небольшой дождь', 'дождь', 'сильный дождь'], 'p': [0.45, 0.35, 0.2]},
            'storm': {'values': ['небольшой дождь', 'дождь', 'гроза', 'сильный дождь'], 'p': [0.15, 0.25, 0.3, 0.3]},
        }
        
        self.obl = {
            'dull': {'values': ['пасмурно', 'облачно'], 'p': [0.7, 0.3]},
            'suncl': {'values': ['облачно', 'малооблачно', 'пасмурно'], 'p': [0.6, 0.3, 0.1]},
            'sun': {'values': ['ясно', 'малооблачно'], 'p': [0.7, 0.3]},
            'sunc': {'values': ['малооблачно', 'ясно', 'облачно'], 'p': [0.6, 0.2, 0.2]},
        }
        
        self.fishs = ['Щука', 'Судак', 'Окунь', 'Берш', 'Речная форель', 'Озерная форель', 'Елец', 'Чехонь', 'Сом', 'Голавль', 'Язь',
         'Карп', 'Жерех', 'Лещ', 'Карась', 'Линь', 'Пескарь', 'Ротан', 'Плотва', 'Красноперка', 'Налим', 'Густера',
         'Амур', 'Ерш', 'Сазан', 'Подуст', 'Толстолобик', 'Вобла', 'Хариус']
        
    def __getitem__(self, idx):
        day_priew = self.dataframe.loc[[idx - 1]]
        day = self.dataframe.loc[[idx]]
        day_next = self.dataframe.loc[[idx + 1]]
        pressure = ','.join(map(str, self.gen_pressure(day_priew['day_pressure'].item(), day['day_pressure'].item(), day_next['day_pressure'].item())))
        temperature = ','.join(map(str, self.gen_temperature(day_priew['day_temp'].item(), day['day_temp'].item(), day_next['day_temp'].item())))
        wind_, gust_ = self.gen_wind(day_priew['day_wind'].item(), day['day_wind'].item(), day_next['day_wind'].item())
        wind = ','.join(map(str, wind_))
        gust = ','.join(map(str, gust_))
        wind_direction = ','.join(self.gen_dir(day['day_dir'].item()))
        phenomen_= self.gen_phenomen(day['day_obl'].item(), day['day_phen'].item())
        humidity = ','.join(map(str,self.gen_hum(phenomen_)))
        phenomen = ','.join(phenomen_)
        uv_index = ','.join(map(str,self.gen_uv(day['month'].item())))
        moon_direction, moon = self.gen_moon(sum(day[fish].item() for fish in self.fishs) / len(self.fishs))
        return {
            'pressure': pressure,
            'temperature': temperature,
            'wind': wind,
            'gust': gust,
            'wind_direction': wind_direction,
            'humidity': humidity,
            'phenomenon': phenomen,
            'uv_index': uv_index,
            'moon_direction': moon_direction,
            'moon': moon,
            'month': day['month'].item(),
            'day': day['day'].item(),
            
        }
        
    def dist_(self, left, right, len_):
        is_revert = False
        if left > right:
            left, right = right, left
            is_revert = True
        result = []
        current = left
        for _ in range(len_ - 1): 
            sub = binom.rvs(right - current, 1 / len_)
            current += sub
            result.append(current)
        result.append(right)
        return result[::-1] if is_revert else result
        
    def gen_pressure(self, pressure_priew, pressure, pressure_next):
        pressure_ = {}
        priew_sub = pressure - pressure_priew
        next_sub = pressure_next - pressure
        current_pressure = pressure_priew
        for time in [15, 18, 21]:
            sub = randint(-2, 0) if priew_sub < 0 else randint(-1, 1)
            current_pressure += sub
        dist = self.dist_(current_pressure, pressure, 5)
        for idx, time in enumerate([0, 3, 6, 9, 12]):
            pressure_.update({time: dist[idx]})
        current_pressure = pressure
        for time in [15, 18, 21]:
            sub = randint(-2, 0) if next_sub < 0 else randint(-1, 1)
            current_pressure += sub
            pressure_.update({time: current_pressure})
        return [pressure_[key] for key in sorted(pressure_)]
    
    def gen_wind(self, wind_priew, wind, wind_next):
        wind_ = {}
        priew_sub = wind - wind_priew
        next_sub = wind_next - wind
        current_wind = randint(0, randint(wind_priew, wind)) if priew_sub > 0 else 0
        dist = self.dist_(current_wind, wind, 5)
        for idx, time in enumerate([0, 3, 6, 9, 12]):
            wind_.update({time: dist[idx]})
        current_wind = wind_[12] + randint(-1, 2)
        current_wind = max(0, current_wind)
        wind_.update({15: current_wind})
        current_wind += randint(-1, 2)
        current_wind = max(0, current_wind)
        wind_.update({18: current_wind})
        current_wind += randint(-3, 0)
        current_wind = max(0, current_wind)
        wind_.update({21: current_wind})
        gust_= {time: wind_[time] + randint(3, 8) for time in self.times}
        return [wind_[key] for key in sorted(wind_)], [gust_[key] for key in sorted(gust_)]
    
    def gen_dir(self,direction):
        dir_ = {}
        if direction == '-':
            direction = choice(list(self.dirs.keys()))
        dist = list(choice([direction] + self.dirs[direction]['values'], 8, p=[0.5] + self.dirs[direction]['p']))
        for idx, time in enumerate(self.times):
            dir_.update({time: dist[idx]})
        return [dir_[key] for key in sorted(dir_)]
    
    def gen_temperature(self, temp_priew, temp, temp_next):
        temp_ = {}
        priew_sub = temp - temp_priew
        next_sub = temp_next - temp
        current_temp = temp_priew
        for time in [15, 18, 21]:
            sub = randint(-3, -1) if priew_sub < 0 else randint(-2, 0)
            current_temp += sub
        dist = self.dist_(current_temp, temp, 5)
        for idx, time in enumerate([0, 3, 6, 9, 12]):
            temp_.update({time: dist[idx]})
        current_temp = temp
        for time in [15, 18, 21]:
            sub = randint(-3, -1) if next_sub < 0 else randint(-2, 0)
            current_temp += sub
            temp_.update({time: current_temp})
        return [temp_[key] for key in sorted(temp_)]
    
    def gen_phenomen(self, obl, phen):
        phens_ = {}
        if obl == '-':
            obl = 'sun'
        dist_olb = list(choice(self.obl[obl]['values'], 8, p=self.obl[obl]['p']))
        if phen == '-':
            for idx, time in enumerate(self.times):
                phens_.update({time: dist_olb[idx]})
            return [phens_[key] for key in sorted(phens_)] 
        dist_phens = list(choice(self.phens[phen]['values'], 8, p=self.phens[phen]['p']))
        for idx, time in enumerate(self.times):
            phens_.update({time: '.'.join((dist_olb[idx], dist_phens[idx]))})
        return [phens_[key] for key in sorted(phens_)]
    
    def gen_hum(self, phens):
        hum_ = []
        current_hum = randint(30, 70)
        for phen in phens:
            if 'дождь' in phen:
                current_hum += randint(0, 20)
                current_hum = min(current_hum, randint(93, 99))
            else:
                current_hum += randint(-10, 5)
                current_hum = max(current_hum, randint(10, 29))
            hum_.append(current_hum)
        return hum_
    
    def gen_uv(self, month):
        uv_ = {}
        if month in [12, 1, 2]:
            uv_ = {0: 0, 3: 0, 6: 0, 9: 0, 12: randint(1, 2), 15: 1, 18: 0, 21:0} 
        elif month in [3, 11, 10, 4]:
            uv_ = {0: 0, 3: 0, 6: 0, 9: 1, 12: randint(2, 3), 15: randint(1, 2), 18: 1, 21: 0}
        elif month in [5, 9]:
            uv_ = {0: 0, 3: 0, 6: randint(0, 1), 9: randint(1, 2), 12: randint(2, 4), 15: randint(2, 3), 18: randint(1, 2), 21: randint(0, 1)}
        else:
            uv_ = {0: 0, 3: 0, 6: randint(1, 2), 9: randint(1, 3), 12: randint(2, 4), 15: randint(1, 3), 18: randint(1, 3), 21: randint(1, 2)}
        return [uv_[key] for key in sorted(uv_)]
    
    def gen_moon(self, forecast):
        if forecast > 0.5:
            return 2* randint(0, 1) - 1, randint(35, 70)
        else:
            if randint(0, 1):
                return 2* randint(0, 1) - 1, randint(0, 35)
            else:
                return 2* randint(0, 1) - 1, randint(70, 99)   

In [158]:
ALONE_KEYS = {'time', 'day', 'month', 'humidity', 'uv_index', 'moon', 'moon_direction'}
DIGIT_KEYS = {'temperature', 'wind', 'gust', 'pressure', 'humidity', 'uv_index'}
CATEGORY_KEYS = {'phenomenon', 'wind_direction'}
MOON_KEYS = {'moon', 'moon_direction'}
SUN_KEYS = {'sun_up', 'sun_down'}
WIND_DIRECTIONS = ['Ю', 'ЮЗ', 'З', 'СЗ', 'С', 'СВ', 'В', 'ЮВ']
PHENOMENONS = ['ясно', 'малооблачно', 'облачно', 'пасмурно', 'небольшой дождь', 'дождь', 'сильный дождь',
               'небольшой снег', 'снег', 'снег с дождём', 'сильный снег', 'гроза', 'мокрый снег']
REGIONS = ['Алтайский край', 'Амурская область', 'Архангельская область', 'Астраханская область',
           'Белгородская область', 'Брянская область', 'Владимирская область', 'Волгоградская область',
           'Вологодская область', 'Воронежская область', 'Еврейская автономная область', 'Забайкальский край',
           'Ивановская область', 'Иркутская область', 'Кабардино-Балкарская республика', 'Калининградская область',
           'Калужская область', 'Камчатский край', 'Карачаево-Черкесская республика', 'Кемеровская область',
           'Кировская область', 'Костромская область', 'Краснодарский край', 'Красноярский край', 'Курганская область',
           'Курская область', 'Ленинградская область', 'Липецкая область', 'Магаданская область', 'Московская область']


num_hours = 8
num_days = 3

def preprocess_(data):
    all_data = {}
    for d in data:
        for key in DIGIT_KEYS:
            temp = list(map(int, d[key].split(',')))
            if key in all_data:
                all_data[key] = all_data[key] + temp
            else:
                all_data[key] = temp
        for key in MOON_KEYS:
            temp = [d[key] for _ in range(num_hours)]
            if key in all_data:
                all_data[key] = all_data[key] + temp
            else:
                all_data[key] = temp
        for key in CATEGORY_KEYS:
            if key in all_data:
                all_data[key] = all_data[key] + d[key].split(',')
            else:
                all_data[key] = d[key].split(',')
        days = [d['day'] for _ in range(num_hours)]
        months = [d['month'] for _ in range(num_hours)]
        if 'day' in all_data:
            all_data['day'] = all_data['day'] + days
        else:
            all_data['day'] = days
        if 'month' in all_data:
            all_data['month'] = all_data['month'] + months
        else:
            all_data['month'] = months
        if 'time' in all_data:
            all_data['time'] = all_data['time'] + list(range(0, num_hours * 3, 3))
        else:
            all_data['time'] = list(range(0, num_hours * 3, 3))
    return all_data

def slice_(data, left_bound, righ_bound):
    slice_data = {
        key: data[key][left_bound: righ_bound] for key in data
    }
    return slice_data

def preprocess_batch_(data):
    vec = {}
    for key in data:
        if key in DIGIT_KEYS and not key in ALONE_KEYS:
            for i in range(len(data[key])):
                key_name = '{}_{}'.format(key, i)
                vec.update({key_name: data[key][i]})
        elif key == 'phenomenon':
            phenomenons_ = [_.split('.') for _ in data[key]]
            for phenomenon in PHENOMENONS:
                for i in range(len(phenomenons_)):
                    key_name = '{}_{}'.format(phenomenon, i)
                    vec.update({key_name: int(phenomenon in phenomenons_[i])})
        elif key == 'wind_direction':
            for wind_direction in WIND_DIRECTIONS:
                for i in range(len(data[key])):
                    key_name = '{}_{}'.format(wind_direction, i)
                    vec.update({key_name: int(wind_direction == data[key][i])})
        elif key == 'month':
            for month in range(1, 13):
                key_name = 'month_{}'.format(month)
                vec.update({key_name: int(month == data[key][-1])})
        elif key == 'time':
            for time in range(0, 24, 3):
                key_name = 'time_{}'.format(time)
                vec.update({key_name: int(time == data[key][-1])})
        elif key == 'moon_direction':
            for moon_direction in [-1, 1]:
                key_name = 'moon_direction_{}'.format(moon_direction)
                vec.update({key_name: int(moon_direction == data[key][-1])})
        elif key in ALONE_KEYS:
            key_name = '{}'.format(key)
            vec.update({key_name: data[key][-1]})
    return vec

def gen_forecast(forecast, time, fish):
    if fish in ['Сом', 'Налим']:
        if time == 21:
            return int(forecast) | (randint(1, 100) <= 20)
        elif time in [0, 3]:
            return int(forecast) | (randint(1, 100) <= 30)
        elif time == 6:
            return int(forecast) | (randint(1, 100) <= 20)
        else:
            return int(forecast) * (randint(1, 100) <= 40)
    else:
        if time in [6, 18]:
            return int(forecast) | (randint(1, 100) <= 40)
        elif time == 9:
            return int(forecast) | (randint(1, 100) <= 30)
        elif time == 12:
            return int(forecast) | (randint(1, 100) <= 10)
        elif time == 15:
            return int(forecast)
        else:
            return int(forecast) * (randint(1, 100) <= 30)
            

In [159]:
gen = DayGenerator(dt)

In [160]:
train_data = []
for idx, row in tqdm(dt.iterrows()):
    if row['Щука'] > -1:
        data = dt[idx - 3: idx + 1]
        day = gen[idx]
        day_1 = gen[idx - 1]
        day_2 = gen[idx - 2]
        day_3 = gen[idx - 3]
        all_data = preprocess_([day_3, day_2, day_1, day])
#         print(all_data)
#         break
        len_data = len(all_data['moon'])
        probs = {fish: [] for fish in gen.fishs}
        for fish in gen.fishs:
            for i in range(num_hours * num_days, len_data + 1):
                slice_data = slice_(all_data, i - num_hours * num_days, i)
                vec = preprocess_batch_(slice_data)
                for fish_ in gen.fishs:
                    vec.update({fish_: int(fish_ == fish)})
                time = 0
                for j in range(0, 24, 3):
                    if vec['time_{}'.format(j)] == 1:
                        time = j
                        break
                vec = {key: vec[key] for key in sorted(vec)}
                forecast = gen_forecast(row[fish], time, fish)
                vec.update({'forecast': forecast})
                train_data.append(vec)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [161]:
len(train_data)

26883

In [162]:
train_dt = pd.DataFrame(train_data)

In [163]:
fishs_dt = [train_dt[train_dt[fish.capitalize()] == 1] for fish in fishs]

In [164]:
y_means = {fishs[i]: fishs_dt[i]['forecast'].mean() for i in range(len(fishs))}
y_means

{'щука': 0.4670981661272923,
 'судак': 0.4002157497303128,
 'окунь': 0.46601941747572817,
 'берш': 0.4153182308522114,
 'речная форель': 0.43581445523193096,
 'озерная форель': 0.29449838187702265,
 'елец': 0.3149946062567422,
 'чехонь': 0.284789644012945,
 'сом': 0.2761596548004315,
 'голавль': 0.39158576051779936,
 'язь': 0.35490830636461707,
 'карп': 0.2686084142394822,
 'жерех': 0.2988133764832794,
 'лещ': 0.39805825242718446,
 'карась': 0.3106796116504854,
 'линь': 0.28047464940668826,
 'пескарь': 0.37971952535059333,
 'ротан': 0.4692556634304207,
 'плотва': 0.44228694714131606,
 'красноперка': 0.4099244875943905,
 'налим': 0.33225458468176916,
 'густера': 0.48220064724919093,
 'амур': 0.2513484358144552,
 'ерш': 0.5415318230852212,
 'сазан': 0.2308522114347357,
 'подуст': 0.2977346278317152,
 'толстолобик': 0.2696871628910464,
 'вобла': 0.313915857605178,
 'хариус': 0.35490830636461707}

In [165]:
train_dt.describe()

Unnamed: 0,day,gust_0,gust_1,gust_10,gust_11,gust_12,gust_13,gust_14,gust_15,gust_16,...,ясно_22,ясно_23,ясно_3,ясно_4,ясно_5,ясно_6,ясно_7,ясно_8,ясно_9,forecast
count,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,...,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0
mean,16.200647,7.764833,7.797195,7.993528,8.00863,8.085221,8.143474,8.170442,7.961165,7.79288,...,0.256742,0.26753,0.206041,0.21575,0.222222,0.221143,0.225458,0.229773,0.221143,0.359781
std,8.85965,2.757161,2.736176,2.797794,2.805873,2.855859,2.902628,2.916561,2.836973,2.770724,...,0.436844,0.442679,0.404468,0.411349,0.415747,0.415025,0.417892,0.420695,0.415025,0.479945
min,1.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,9.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,16.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,24.0,9.0,9.0,10.0,10.0,10.0,10.0,10.0,10.0,9.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
max,31.0,22.0,22.0,20.0,20.0,20.0,20.0,20.0,20.0,20.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [166]:
# f1 = train_dt[train_dt['forecast'] == 1]
# f0 = train_dt[train_dt['forecast'] == 0].sample(len(f1))
# train_dt = pd.concat([f0, f1])

In [167]:
train_dt = train_dt.sample(frac=1.0).reset_index(drop=True)

In [168]:
y = train_dt['forecast']
del train_dt['forecast']

In [169]:
y.std()

0.47994507759235244

In [170]:
train_dt.describe()

Unnamed: 0,day,gust_0,gust_1,gust_10,gust_11,gust_12,gust_13,gust_14,gust_15,gust_16,...,ясно_21,ясно_22,ясно_23,ясно_3,ясно_4,ясно_5,ясно_6,ясно_7,ясно_8,ясно_9
count,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,...,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0,26883.0
mean,16.200647,7.764833,7.797195,7.993528,8.00863,8.085221,8.143474,8.170442,7.961165,7.79288,...,0.248112,0.256742,0.26753,0.206041,0.21575,0.222222,0.221143,0.225458,0.229773,0.221143
std,8.85965,2.757161,2.736176,2.797794,2.805873,2.855859,2.902628,2.916561,2.836973,2.770724,...,0.431925,0.436844,0.442679,0.404468,0.411349,0.415747,0.415025,0.417892,0.420695,0.415025
min,1.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,9.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,16.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,24.0,9.0,9.0,10.0,10.0,10.0,10.0,10.0,10.0,9.0,...,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,31.0,22.0,22.0,20.0,20.0,20.0,20.0,20.0,20.0,20.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [171]:
X = train_dt.values

In [172]:
from sklearn.model_selection import train_test_split

X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)

In [173]:
X_train.shape

(21506, 655)

In [174]:
X_test.shape

(5377, 655)

In [175]:
from catboost import CatBoostClassifier

model = CatBoostClassifier()

In [176]:
model.fit(X_train, y_train)

Learning rate set to 0.038189
0:	learn: 0.6810519	total: 28.3ms	remaining: 28.3s
1:	learn: 0.6675217	total: 45.6ms	remaining: 22.7s
2:	learn: 0.6557397	total: 62.8ms	remaining: 20.9s
3:	learn: 0.6445059	total: 79.7ms	remaining: 19.9s
4:	learn: 0.6351122	total: 96.8ms	remaining: 19.3s
5:	learn: 0.6252901	total: 113ms	remaining: 18.7s
6:	learn: 0.6175708	total: 128ms	remaining: 18.1s
7:	learn: 0.6080811	total: 145ms	remaining: 18s
8:	learn: 0.6002262	total: 165ms	remaining: 18.1s
9:	learn: 0.5930597	total: 186ms	remaining: 18.4s
10:	learn: 0.5857254	total: 203ms	remaining: 18.2s
11:	learn: 0.5796055	total: 220ms	remaining: 18.1s
12:	learn: 0.5736042	total: 236ms	remaining: 17.9s
13:	learn: 0.5680124	total: 254ms	remaining: 17.9s
14:	learn: 0.5632838	total: 271ms	remaining: 17.8s
15:	learn: 0.5591314	total: 289ms	remaining: 17.8s
16:	learn: 0.5558794	total: 304ms	remaining: 17.6s
17:	learn: 0.5523958	total: 322ms	remaining: 17.6s
18:	learn: 0.5487519	total: 340ms	remaining: 17.5s
19:	lear

163:	learn: 0.4372691	total: 2.96s	remaining: 15.1s
164:	learn: 0.4370522	total: 2.98s	remaining: 15.1s
165:	learn: 0.4368931	total: 2.99s	remaining: 15s
166:	learn: 0.4365025	total: 3.01s	remaining: 15s
167:	learn: 0.4363296	total: 3.03s	remaining: 15s
168:	learn: 0.4359470	total: 3.05s	remaining: 15s
169:	learn: 0.4357573	total: 3.07s	remaining: 15s
170:	learn: 0.4355802	total: 3.08s	remaining: 15s
171:	learn: 0.4353778	total: 3.1s	remaining: 14.9s
172:	learn: 0.4350928	total: 3.12s	remaining: 14.9s
173:	learn: 0.4349585	total: 3.14s	remaining: 14.9s
174:	learn: 0.4347229	total: 3.16s	remaining: 14.9s
175:	learn: 0.4342050	total: 3.18s	remaining: 14.9s
176:	learn: 0.4338825	total: 3.19s	remaining: 14.9s
177:	learn: 0.4336147	total: 3.21s	remaining: 14.8s
178:	learn: 0.4333039	total: 3.23s	remaining: 14.8s
179:	learn: 0.4328884	total: 3.25s	remaining: 14.8s
180:	learn: 0.4327080	total: 3.27s	remaining: 14.8s
181:	learn: 0.4324398	total: 3.29s	remaining: 14.8s
182:	learn: 0.4321736	tot

328:	learn: 0.3944588	total: 5.96s	remaining: 12.2s
329:	learn: 0.3942061	total: 5.97s	remaining: 12.1s
330:	learn: 0.3939449	total: 5.99s	remaining: 12.1s
331:	learn: 0.3937122	total: 6.01s	remaining: 12.1s
332:	learn: 0.3934409	total: 6.03s	remaining: 12.1s
333:	learn: 0.3931826	total: 6.04s	remaining: 12s
334:	learn: 0.3928377	total: 6.06s	remaining: 12s
335:	learn: 0.3925881	total: 6.08s	remaining: 12s
336:	learn: 0.3924364	total: 6.09s	remaining: 12s
337:	learn: 0.3921593	total: 6.11s	remaining: 12s
338:	learn: 0.3919337	total: 6.13s	remaining: 11.9s
339:	learn: 0.3917376	total: 6.14s	remaining: 11.9s
340:	learn: 0.3915269	total: 6.16s	remaining: 11.9s
341:	learn: 0.3913146	total: 6.18s	remaining: 11.9s
342:	learn: 0.3911156	total: 6.2s	remaining: 11.9s
343:	learn: 0.3908658	total: 6.21s	remaining: 11.9s
344:	learn: 0.3906230	total: 6.23s	remaining: 11.8s
345:	learn: 0.3904264	total: 6.25s	remaining: 11.8s
346:	learn: 0.3901235	total: 6.27s	remaining: 11.8s
347:	learn: 0.3898509	t

489:	learn: 0.3631414	total: 8.72s	remaining: 9.08s
490:	learn: 0.3629637	total: 8.74s	remaining: 9.06s
491:	learn: 0.3627439	total: 8.76s	remaining: 9.04s
492:	learn: 0.3626309	total: 8.77s	remaining: 9.02s
493:	learn: 0.3624578	total: 8.79s	remaining: 9s
494:	learn: 0.3622871	total: 8.8s	remaining: 8.98s
495:	learn: 0.3620857	total: 8.82s	remaining: 8.96s
496:	learn: 0.3619006	total: 8.84s	remaining: 8.94s
497:	learn: 0.3617553	total: 8.85s	remaining: 8.92s
498:	learn: 0.3615799	total: 8.87s	remaining: 8.9s
499:	learn: 0.3614536	total: 8.88s	remaining: 8.88s
500:	learn: 0.3613010	total: 8.9s	remaining: 8.87s
501:	learn: 0.3611556	total: 8.92s	remaining: 8.85s
502:	learn: 0.3610318	total: 8.94s	remaining: 8.83s
503:	learn: 0.3608533	total: 8.95s	remaining: 8.81s
504:	learn: 0.3606487	total: 8.97s	remaining: 8.79s
505:	learn: 0.3605248	total: 8.98s	remaining: 8.77s
506:	learn: 0.3603630	total: 9s	remaining: 8.75s
507:	learn: 0.3601987	total: 9.02s	remaining: 8.73s
508:	learn: 0.3600121

654:	learn: 0.3389822	total: 11.5s	remaining: 6.06s
655:	learn: 0.3388312	total: 11.5s	remaining: 6.04s
656:	learn: 0.3387228	total: 11.5s	remaining: 6.03s
657:	learn: 0.3385466	total: 11.6s	remaining: 6.01s
658:	learn: 0.3383872	total: 11.6s	remaining: 5.99s
659:	learn: 0.3382191	total: 11.6s	remaining: 5.97s
660:	learn: 0.3380588	total: 11.6s	remaining: 5.95s
661:	learn: 0.3379946	total: 11.6s	remaining: 5.94s
662:	learn: 0.3378604	total: 11.6s	remaining: 5.92s
663:	learn: 0.3377404	total: 11.7s	remaining: 5.9s
664:	learn: 0.3375758	total: 11.7s	remaining: 5.88s
665:	learn: 0.3374710	total: 11.7s	remaining: 5.86s
666:	learn: 0.3373196	total: 11.7s	remaining: 5.85s
667:	learn: 0.3371962	total: 11.7s	remaining: 5.83s
668:	learn: 0.3371044	total: 11.7s	remaining: 5.81s
669:	learn: 0.3370056	total: 11.8s	remaining: 5.79s
670:	learn: 0.3368837	total: 11.8s	remaining: 5.77s
671:	learn: 0.3367496	total: 11.8s	remaining: 5.75s
672:	learn: 0.3366453	total: 11.8s	remaining: 5.74s
673:	learn: 0

813:	learn: 0.3195479	total: 15.1s	remaining: 3.44s
814:	learn: 0.3194327	total: 15.1s	remaining: 3.42s
815:	learn: 0.3193198	total: 15.1s	remaining: 3.4s
816:	learn: 0.3192204	total: 15.1s	remaining: 3.39s
817:	learn: 0.3191081	total: 15.1s	remaining: 3.37s
818:	learn: 0.3190270	total: 15.2s	remaining: 3.35s
819:	learn: 0.3189474	total: 15.2s	remaining: 3.33s
820:	learn: 0.3188501	total: 15.2s	remaining: 3.32s
821:	learn: 0.3187157	total: 15.2s	remaining: 3.3s
822:	learn: 0.3186050	total: 15.3s	remaining: 3.28s
823:	learn: 0.3185107	total: 15.3s	remaining: 3.27s
824:	learn: 0.3183607	total: 15.3s	remaining: 3.25s
825:	learn: 0.3182687	total: 15.3s	remaining: 3.23s
826:	learn: 0.3181301	total: 15.4s	remaining: 3.21s
827:	learn: 0.3180225	total: 15.4s	remaining: 3.2s
828:	learn: 0.3179249	total: 15.4s	remaining: 3.18s
829:	learn: 0.3178371	total: 15.4s	remaining: 3.16s
830:	learn: 0.3176868	total: 15.5s	remaining: 3.14s
831:	learn: 0.3175728	total: 15.5s	remaining: 3.13s
832:	learn: 0.3

977:	learn: 0.3016556	total: 19s	remaining: 428ms
978:	learn: 0.3015965	total: 19s	remaining: 409ms
979:	learn: 0.3014966	total: 19.1s	remaining: 389ms
980:	learn: 0.3014012	total: 19.1s	remaining: 370ms
981:	learn: 0.3013100	total: 19.1s	remaining: 350ms
982:	learn: 0.3012092	total: 19.1s	remaining: 331ms
983:	learn: 0.3011191	total: 19.2s	remaining: 312ms
984:	learn: 0.3010219	total: 19.2s	remaining: 292ms
985:	learn: 0.3008901	total: 19.2s	remaining: 273ms
986:	learn: 0.3007460	total: 19.2s	remaining: 253ms
987:	learn: 0.3006157	total: 19.3s	remaining: 234ms
988:	learn: 0.3005216	total: 19.3s	remaining: 215ms
989:	learn: 0.3004277	total: 19.3s	remaining: 195ms
990:	learn: 0.3003341	total: 19.3s	remaining: 176ms
991:	learn: 0.3002443	total: 19.4s	remaining: 156ms
992:	learn: 0.3001302	total: 19.4s	remaining: 137ms
993:	learn: 0.3000334	total: 19.4s	remaining: 117ms
994:	learn: 0.2999221	total: 19.4s	remaining: 97.7ms
995:	learn: 0.2997856	total: 19.5s	remaining: 78.2ms
996:	learn: 0.

<catboost.core.CatBoostClassifier at 0x1e9862487c8>

In [177]:
preds = model.predict(X_test)

In [178]:
from sklearn.metrics import accuracy_score

print(accuracy_score(preds, y_test))

0.83187651106565


In [179]:
model.predict_proba(X_test)

array([[0.97360241, 0.02639759],
       [0.97453644, 0.02546356],
       [0.96567002, 0.03432998],
       ...,
       [0.58090447, 0.41909553],
       [0.70876269, 0.29123731],
       [0.96317621, 0.03682379]])

In [180]:
from joblib import dump

In [181]:
dump(model, 'catboost_0.4.model')

['catboost_0.4.model']