In [None]:
import os
import gc
import time

from IPython.display import clear_output
from tqdm import tqdm
from tqdm.contrib import tzip

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize
import matplotlib.pyplot as plt

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
DATA_DIR = "/kaggle/input/hms-harmful-brain-activity-classification"
# Competition metadata: one row per labelled window (patient_id, spectrogram_id, offsets, ...).
train_df = pd.read_csv(f"{DATA_DIR}/train.csv")

In [None]:
# Load every precomputed spectrogram at once. The .npy file holds a single pickled object:
# a dict spectrogram_id -> dict of per-channel 2-D arrays ('LL', 'RL', 'LP', 'RP').
# NOTE: allow_pickle=True runs the unpickler on load — only use it with trusted files.
SPECS_PATH = '/kaggle/input/default-specs/spectograms.npy'
dct = np.load(SPECS_PATH, allow_pickle=True).item()

In [None]:
# Split by patient (not by row) so that no patient appears in more than one split:
# 70% of patients for training, the remaining 30% halved into validation and test (15% / 15%).
# The split only depends on the number of patients and random_state, so passing just the
# patient array gives the same partition as passing an extra dummy index array.
patients = train_df.patient_id.unique()
patients_train, patients_val_test = train_test_split(patients, test_size=0.3, random_state=123)
patients_val, patients_test = train_test_split(patients_val_test, test_size=0.5, random_state=123)

In [None]:
def _rows_of(patient_ids):
    """Return an independent copy of the train_df rows that belong to ``patient_ids``."""
    return train_df[train_df.patient_id.isin(patient_ids)].copy()


# NOTE: despite the names, these are row-level DataFrames (one row per labelled window).
train_patients = _rows_of(patients_train)
val_patients = _rows_of(patients_val)
test_patients = _rows_of(patients_test)

In [None]:
# Streaming mean / std: running np.mean / np.std on the full array does not fit in memory.
def online_mean_std(data):
    """Return ``(std, mean)`` of all values in ``data``, computed one slice at a time.

    Iterates over ``data`` (an (N, ...) array or any iterable of equally-shaped arrays)
    with Welford's update to keep a per-element running mean and sum of squared
    deviations. It then merges those per-element statistics into one scalar mean and std
    for the data as a whole. NaNs are replaced by 0 (and +/-inf by large finite
    numbers, as ``np.nan_to_num`` does) before accumulation.

    The result equals ``np.std(np.nan_to_num(arr))`` / ``np.mean(np.nan_to_num(arr))``
    (population std, ddof=0). An earlier version returned ``sqrt(mean of per-element
    variances)``. That value ignores how much the per-element means differ from each
    other, so it underestimated the std of the data as a whole.

    Args:
        data: iterable of array-likes that all have the same shape.

    Returns:
        tuple ``(std, mean)`` of numpy float64 scalars.

    Raises:
        ValueError: if ``data`` yields no elements.
    """
    n = 0
    mean = 0.0  # per-element running mean
    m2 = 0.0    # per-element running sum of squared deviations from that mean

    for x in tqdm(data):
        n += 1
        # Accumulate in float64 so the running sums stay accurate over many slices.
        x = np.nan_to_num(np.asarray(x, dtype=np.float64))
        delta = x - mean
        mean = mean + delta / n
        m2 = m2 + delta * (x - mean)

    if n == 0:
        raise ValueError('online_mean_std: data is empty')

    mean = np.asarray(mean)
    m2 = np.asarray(m2)
    global_mean = mean.mean()
    # Total sum of squares = within-element part (m2) + between-element part
    # (n times the squared offset of each element's mean from the global mean).
    total_ss = m2.sum() + n * np.square(mean - global_mean).sum()
    variance = total_ss / (n * mean.size)
    return np.sqrt(variance), global_mean

In [None]:
# Per-channel mean and std over the training windows, stored with shape (4, 1, 1) so they
# broadcast over height and width when `normalize` is applied to a (4, H, W) array.
# Each window is the 300-column slice starting at column offset_seconds // 2 (the same
# slice SpecDataset cuts). Windows are streamed one by one into online_mean_std instead of
# first being concatenated into a single (n_rows, H, 300) array per channel. That
# concatenation copied every training window into memory at once, which is exactly
# what the online algorithm is meant to avoid.
means = []
stds = []
for el in ['LL', 'RL', 'LP', 'RP']:
    windows = (
        dct[sid][el][:, int(slos) // 2: int(slos) // 2 + 300]
        for sid, slos in zip(train_patients.spectrogram_id, train_patients.spectrogram_label_offset_seconds)
    )
    std, mean = online_mean_std(windows)
    means.append(mean)
    stds.append(std)

norm_mean = np.array(means).reshape((4, 1, 1))
norm_std = np.array(stds).reshape((4, 1, 1))

### Срез `[None, :, int(slos)//2 : int(slos)//2 + 300]` вырезает из каждой спектрограммы окно длиной 10 минут (300 столбцов). Новая ось (`None`) добавляется, чтобы затем конкатенировать спектрограммы вдоль неё.

In [None]:
def normalize(x):
    """Standardise a [c, h, w] array channel-wise with the global norm_mean / norm_std.

    Both statistics are module-level arrays that broadcast against ``x``.
    """
    centred = x - norm_mean
    return centred / norm_std

class SpecDataset(Dataset):
    """4-channel spectrogram windows for the autoencoder (inputs only, no labels).

    For row ``i`` of ``df``, item ``i`` is built as follows:
    - A 300-column window starting at column ``int(spectrogram_label_offset_seconds) // 2``
      is cut from each of the 'LL', 'RL', 'LP', 'RP' arrays of
      ``dct[spectrogram_id]``.
    - The windows are stacked into a (4, H, 300) array and normalised with
      ``normalize``.
    - The result is converted to a float32 tensor, NaNs are replaced with 0, and it is
      resized to ``img_size``.

    Args:
        df: DataFrame with ``spectrogram_id`` and ``spectrogram_label_offset_seconds``.
        dct: mapping spectrogram_id -> dict of 2-D arrays keyed by channel name.
        img_size: output (height, width). The author notes the raw windows are 99 x 300,
            so the default matches the raw window size.
    """

    CHANNELS = ('LL', 'RL', 'LP', 'RP')
    WINDOW = 300  # number of spectrogram columns per window

    def __init__(self, df, dct, img_size=(99, 300)):
        self.df = df
        self.dct = dct
        self.image_size = img_size
        # Built once here instead of on every __getitem__ call.
        self.resize = Resize([img_size[0], img_size[1]])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]  # look the row up once
        spec = self.dct[row.spectrogram_id]
        start = int(row.spectrogram_label_offset_seconds) // 2
        x = np.stack([spec[ch][:, start:start + self.WINDOW] for ch in self.CHANNELS], axis=0)
        x = torch.from_numpy(normalize(x)).float()
        x = torch.nan_to_num(x, 0)
        return self.resize(x)

## Сеть состоит из ResNet-блоков: в энкодере каждый блок уменьшает ширину и высоту в два раза и увеличивает число каналов в два раза, в декодере — наоборот, увеличивает ширину и высоту вдвое и вдвое уменьшает число каналов

In [None]:
# class ResNetBlock(nn.Module):
#     def __init__(self, in_channels, kernel_size, img_size, modify=False, bn=True):
#         super().__init__()
#         self.modify = modify
#         if modify=='downsample':
#             self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels*2, stride=2, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
#             self.conv2 = nn.Conv2d(in_channels=in_channels*2, out_channels=in_channels*2, kernel_size=kernel_size, padding=kernel_size//2,bias=False)
# #             [C, H, W]
#             if bn:
#                 self.bn1 = nn.LayerNorm([2*in_channels, img_size//2, img_size//2])
#                 self.bn2 = nn.LayerNorm([2*in_channels, img_size//2, img_size//2])
#             else:
#                 self.bn1 = nn.Identity()
#                 self.bn2 = nn.Identity()
                
#         elif modify=='upsample':
#             self.conv1 = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels//2, stride=2, kernel_size=kernel_size, output_padding=1, padding=kernel_size//2, bias=False)
#             self.conv2 = nn.Conv2d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
#             self.bn1 = nn.LayerNorm([in_channels//2, img_size, img_size])
#             self.bn2 = nn.LayerNorm([in_channels//2, img_size, img_size])
#         else:
#             self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, padding=kernel_size//2)
#             self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, padding=kernel_size//2)
#             self.bn1 = nn.LayerNorm([in_channels, img_size, img_size])
#             self.bn2 = nn.LayerNorm([in_channels, img_size, img_size])
#         self.act = nn.ReLU()
        
#         if modify=='downsample':
#             self.proj = nn.Conv2d(in_channels=in_channels, out_channels=in_channels*2, stride=2, kernel_size=kernel_size, padding=kernel_size//2)
#         if modify=='upsample':
#             self.proj = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels//2, stride=2, kernel_size=kernel_size, output_padding=1, padding=kernel_size//2)


#     def forward(self, x):
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.act(out)
#         out = self.conv2(out)
#         out = self.bn2(out)
#         if self.modify:
#             x = self.proj(x)
#         out = x + out
#         out = self.act(out)
#         return out

    
# class Encoder(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv = nn.Conv2d(4, 16, 7, 1, 7//2)
#         self.rnb1 = ResNetBlock(16, 3, 256, modify='downsample')
#         self.rnb2 = ResNetBlock(32, 3, 128, modify='downsample')
#         self.rnb3 = ResNetBlock(64, 3, 64, modify='downsample')
#         self.rnb4 = ResNetBlock(128, 3, 32, modify='downsample')
#         self.rnb5 = ResNetBlock(256, 3, 16, modify='downsample')
#         self.rnb6 = ResNetBlock(512, 3, 8, modify='downsample')
#         self.rnb7 = ResNetBlock(1024, 3, 4, modify='downsample')
#         self.rnb8 = ResNetBlock(2048, 3, 2, modify='downsample')
        
#     def forward(self, x):
#         x = self.conv(x)
#         x = self.rnb1(x)
#         x = self.rnb2(x)
#         x = self.rnb3(x)
#         x = self.rnb4(x)
#         x = self.rnb5(x)
#         x = self.rnb6(x)
#         x = self.rnb7(x)
#         x = self.rnb8(x)
#         return x
    
# class Decoder(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.rnb1 = ResNetBlock(4096, 3, 2, modify='upsample')
#         self.rnb2 = ResNetBlock(2048, 3, 4, modify='upsample')
#         self.rnb3 = ResNetBlock(1024, 3, 8, modify='upsample')
#         self.rnb4 = ResNetBlock(512, 3, 16, modify='upsample')
#         self.rnb5 = ResNetBlock(256, 3, 32, modify='upsample')
#         self.rnb6 = ResNetBlock(128, 3, 64, modify='upsample')
#         self.rnb7 = ResNetBlock(64, 3, 128, modify='upsample')
#         self.rnb8 = ResNetBlock(32, 3, 256, modify='upsample')
#         self.conv = nn.Conv2d(16, 4, 3, 1, 3//2)

#     def forward(self, x):
#         x = self.rnb1(x)
#         x = self.rnb2(x)
#         x = self.rnb3(x)
#         x = self.rnb4(x)
#         x = self.rnb5(x)
#         x = self.rnb6(x)
#         x = self.rnb7(x)
#         x = self.rnb8(x)
#         x = self.conv(x)
#         return x

In [None]:
class ResNetBlock(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN plus a shortcut, followed by ReLU.

    Args:
        in_channels: number of input channels.
        kernel_size: kernel size of every convolution; padding is ``kernel_size // 2``.
        modify: one of
            - ``'downsample'``: the first conv has stride 2 and doubles the channels.
            - ``'upsample'``: the first conv is a stride-2 transposed conv
              (output_padding=1) that halves the channels.
            - any falsy value: channels are kept.
            For 'downsample' / 'upsample' the shortcut goes through a matching projection
            ``self.proj``; otherwise the shortcut is the identity.
        bn: use BatchNorm2d after each main-path conv; ``False`` uses nn.Identity. This
            is now honoured for every ``modify`` mode, not only 'downsample'.

    Raises:
        ValueError: for an unknown ``modify`` value. Previously such a value built a
            plain block and then failed with AttributeError on ``self.proj`` in ``forward``.
    """

    _MODES = ('downsample', 'upsample')

    def __init__(self, in_channels, kernel_size, modify=False, bn=True):
        super().__init__()
        if modify and modify not in self._MODES:
            raise ValueError(f"modify must be 'downsample', 'upsample' or falsy, got {modify!r}")
        self.modify = modify
        pad = kernel_size // 2

        # Module names / creation order are unchanged so saved state_dicts still load.
        if modify == 'downsample':
            out_channels = in_channels * 2
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=pad, bias=False)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=pad, bias=False)
        elif modify == 'upsample':
            out_channels = in_channels // 2
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2,
                                            padding=pad, output_padding=1, bias=False)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=pad, bias=False)
        else:
            out_channels = in_channels
            # Plain blocks keep bias=True as in the original, so parameter sets are unchanged.
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=pad)

        if bn:
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()
        self.act = nn.ReLU()

        # Projection shortcut so the residual sum has matching channels / spatial size.
        if modify == 'downsample':
            self.proj = nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=pad)
        elif modify == 'upsample':
            self.proj = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2,
                                           padding=pad, output_padding=1)

    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = self.proj(x) if self.modify else x
        return self.act(out + shortcut)

    
class Encoder(nn.Module):
    """Convolutional encoder.

    A 7x7 stem conv maps 4 -> 16 channels. Eight 'downsample' ResNetBlocks follow, each
    doubling the channel count (16 -> 32 -> ... -> 4096).

    The blocks are registered as attributes ``rnb1`` ... ``rnb8`` rather than in a
    ModuleList, so state_dict keys match previously saved checkpoints.
    """

    N_BLOCKS = 8

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 16, 7, 1, 7 // 2)
        channels = 16
        for idx in range(1, self.N_BLOCKS + 1):
            setattr(self, f'rnb{idx}', ResNetBlock(channels, 3, modify='downsample'))
            channels *= 2

    def forward(self, x):
        h = self.conv(x)
        for idx in range(1, self.N_BLOCKS + 1):
            h = getattr(self, f'rnb{idx}')(h)
        return h
    
class Decoder(nn.Module):
    """Mirror of the encoder.

    Eight 'upsample' ResNetBlocks halve the channel count (4096 -> ... -> 16). A final
    3x3 conv then maps 16 -> 4 channels.

    The blocks are registered as ``rnb1`` ... ``rnb8`` and the head as ``conv``, so
    state_dict keys match previously saved checkpoints.
    """

    N_BLOCKS = 8

    def __init__(self):
        super().__init__()
        channels = 4096
        for idx in range(1, self.N_BLOCKS + 1):
            setattr(self, f'rnb{idx}', ResNetBlock(channels, 3, modify='upsample'))
            channels //= 2
        self.conv = nn.Conv2d(16, 4, 3, 1, 3 // 2)

    def forward(self, x):
        for idx in range(1, self.N_BLOCKS + 1):
            x = getattr(self, f'rnb{idx}')(x)
        return self.conv(x)

In [None]:
class SimpleAE(nn.Module):
    """Autoencoder: ``Decoder()(Encoder()(x))``, held in ``self.net``.

    ``self.net`` is an nn.Sequential, so its state_dict keys are ``net.0.*`` and ``net.1.*``.
    """

    def __init__(self):
        super().__init__()
        encoder = Encoder()
        decoder = Decoder()
        self.net = nn.Sequential(encoder, decoder)

    def forward(self, x):
        reconstruction = self.net(x)
        return reconstruction

In [None]:
def run_epoch(model, dataloader, loss_fn, optimizer, epoch, device, scaler):
    """Train ``model`` as an autoencoder for one pass over ``dataloader``.

    Each batch is moved to ``device`` and reconstructed. The batch is scored with
    ``loss_fn(batch, reconstruction)``, the loss is back-propagated, and ``optimizer``
    steps and then zeroes the gradients.

    Args:
        model: network to train (moved to ``device`` and put in train mode).
        dataloader: iterable of input tensors (no labels); must support ``len``.
        loss_fn: reconstruction loss, called as ``loss_fn(x, x_recon)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index — currently unused.
        device: device for the model and the batches.
        scaler: AMP GradScaler — currently unused (mixed precision is disabled).

    Returns:
        Mean of the per-batch losses, ignoring NaNs (``np.nanmean``).
    """
    model = model.to(device)
    model.train()
    batch_losses = []
    for batch in tqdm(dataloader, total=len(dataloader)):
        inputs = batch.to(device)
        loss = loss_fn(inputs, model(inputs))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        batch_losses.append(loss.detach().cpu().item())
    return np.nanmean(batch_losses)

In [None]:
def evaluate(model, dataloader, loss_fn, device, scaler):
    """Average reconstruction loss of ``model`` over ``dataloader`` without gradients.

    Args:
        model: network to evaluate (moved to ``device`` and left in eval mode).
        dataloader: iterable of input tensors; must support ``len``.
        loss_fn: reconstruction loss, called as ``loss_fn(x, x_recon)``.
        device: device for the model and the batches.
        scaler: AMP GradScaler — currently unused (mixed precision is disabled).

    Returns:
        Mean of the per-batch losses, ignoring NaNs (``np.nanmean``).
    """
    model = model.to(device)
    batch_losses = []
    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader, total=len(dataloader)):
            inputs = batch.to(device)
            batch_loss = loss_fn(inputs, model(inputs))
            batch_losses.append(batch_loss.detach().cpu().item())
    return np.nanmean(batch_losses)

In [None]:
# AMP gradient scaler passed through to run_epoch / evaluate. NOTE: mixed precision is
# commented out there, so the scaler is currently not used.
scaler = torch.cuda.amp.GradScaler()


def run_experiment(model, dataloader_train, dataloader_val, loss_fn, optimizer, num_epochs, device,
                   stop_after=5, scaler=scaler, time_budget=12 * 60 * 60 - 600,
                   checkpoint_path='best_model.pth', optimizer_path='optimizer.pth',
                   restore_best=True):
    """Train ``model`` with early stopping on validation loss and a wall-clock budget.

    After every epoch the validation loss is compared with the best so far. On
    improvement, the model's and the optimizer's ``state_dict`` are saved. Training
    stops at whichever comes first:
    - ``num_epochs`` epochs have run;
    - ``stop_after`` consecutive epochs pass without improvement;
    - the remaining ``time_budget`` (seconds) is shorter than the last epoch took.

    Args:
        model, dataloader_train, dataloader_val, loss_fn, optimizer, device, scaler:
            forwarded to ``run_epoch`` / ``evaluate``.
        num_epochs: maximum number of epochs.
        stop_after: patience, in epochs without validation improvement.
        time_budget: seconds available for training. The default is 12 h minus a 10 min
            margin (presumably the Kaggle session limit).
        checkpoint_path: where the best model ``state_dict`` is written.
        optimizer_path: where the matching optimizer ``state_dict`` is written. Load it
            with ``optimizer.load_state_dict(torch.load(path))``. Pickling the whole
            optimizer object, as before, gives back an optimizer bound to copies of the
            parameters rather than to the model's own.
        restore_best: reload the best saved weights into ``model`` before returning.
            Previously the last-epoch weights were returned, so a test evaluation run
            afterwards did not use the model selected on validation.

    Returns:
        ``(losses_train, losses_val, model)``: per-epoch mean losses and the model.
    """
    losses_train = []
    losses_val = []
    best_loss_val = np.inf
    epochs_without_improvement = 0
    total_runtime = 0
    saved_best = False

    for epoch in range(num_epochs):
        start = time.time()

        if epochs_without_improvement == stop_after:
            print(f'Обучение остановлено, так как лосс на валидации не падал {stop_after} эпох')
            break

        loss_train = run_epoch(model, dataloader_train, loss_fn, optimizer, epoch, device, scaler)
        loss_val = evaluate(model, dataloader_val, loss_fn, device, scaler)
        losses_train.append(loss_train)
        losses_val.append(loss_val)
        clear_output()

        if best_loss_val > loss_val:
            torch.save(model.state_dict(), checkpoint_path)
            torch.save(optimizer.state_dict(), optimizer_path)
            best_loss_val = loss_val
            saved_best = True
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        print(f"epoch: {str(epoch).zfill(3)} | loss_train: {loss_train:5.5f} | loss_val: {loss_val:5.5f} | best_loss: {best_loss_val:5.5f}")

        fig, ax = plt.subplots()
        ax.plot(losses_train, label='Loss train')
        ax.plot(losses_val, label='Loss val')
        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        ax.legend()
        plt.show()
        plt.close(fig)

        runtime = time.time() - start
        total_runtime += runtime
        # Stop if another epoch of the same length would not fit into the remaining budget.
        if time_budget - total_runtime < runtime:
            break

    if restore_best and saved_best:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return losses_train, losses_val, model

In [None]:
# Datasets and loaders for the three patient splits: 256x256 inputs, batches of 128.
# Only the training loader shuffles and drops its incomplete last batch.
IMG_SIZE = (256, 256)
BATCH_SIZE = 128

dataset_train = SpecDataset(train_patients, dct, img_size=IMG_SIZE)
dataset_val = SpecDataset(val_patients, dct, img_size=IMG_SIZE)
dataset_test = SpecDataset(test_patients, dct, img_size=IMG_SIZE)


def _make_loader(dataset, train):
    """DataLoader shared by all splits; shuffle / drop_last are enabled only for training."""
    return DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=train, drop_last=train)


dataloader_train = _make_loader(dataset_train, train=True)
dataloader_val = _make_loader(dataset_val, train=False)
dataloader_test = _make_loader(dataset_test, train=False)

In [None]:
# small, _ = torch.utils.data.random_split(dataset_train, [256, len(dataset_train) - 256])
# small_val, _ = torch.utils.data.random_split(dataset_train, [256, len(dataset_train) - 256])
# small_dataloader_train = DataLoader(
#     dataset=small,
#     batch_size=128,
#     shuffle=True,
#     drop_last=True
# )

# small_dataloader_val = DataLoader(
#     dataset=small_val,
#     batch_size=128,
#     shuffle=False,
#     drop_last=False
# )

In [None]:
# Training configuration.
SEED = 123  # same value as the random_state of the patient split
# Seed torch so model initialisation and DataLoader shuffling are reproducible across runs
# (the original cell set no seed).
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4
model = SimpleAE()
model = nn.DataParallel(model)  # state_dict keys gain a 'module.' prefix
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
num_epochs = 100

In [None]:
def init_weights(w):
    """Xavier-uniform initialise the weight of Linear / Conv2d / ConvTranspose2d modules.

    Meant for ``model.apply``. Biases and every other module type are left untouched.
    """
    if isinstance(w, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(w.weight)
    
# Xavier-uniform initialisation of every Linear / Conv2d / ConvTranspose2d weight.
# NOTE: the load_state_dict call in the next cell overwrites these weights with the checkpoint.
model.apply(init_weights)

In [None]:
# Resume from previously trained autoencoder weights (this replaces the initialisation above).
PRETRAINED_WEIGHTS = '/kaggle/input/autoencoder-weights/best_model.pth'
model.load_state_dict(torch.load(PRETRAINED_WEIGHTS, map_location=torch.device(device)))
# NOTE(review): unpickling a whole optimizer object (as saved with torch.save(optimizer))
# yields an optimizer bound to *copies* of the parameters, not to `model`'s. Resuming with it
# would therefore not update the model; prefer optimizer.state_dict() / load_state_dict().
# optimizer = torch.load('/kaggle/input/autoencoder-weights/optimizer.pth', map_location=torch.device(device))

In [None]:
# def nan_hook(self, inp, output):
#     if not isinstance(output, tuple):
#         outputs = [output]
#     else:
#         outputs = output

#     for i, out in enumerate(outputs):
#         nan_mask = torch.isnan(out)
#         if nan_mask.any():
#             print("In", self.__class__.__name__)
#             raise RuntimeError(f"Found NAN in output {i} at indices: ", nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])

# for submodule in model.modules():
#     submodule.register_forward_hook(nan_hook)

In [None]:
losses_train, losses_val, model = run_experiment(
    model,
    dataloader_train,
    dataloader_val,
    loss_fn,
    optimizer,
    num_epochs,
    device,
    stop_after=15,  # tolerate up to 15 epochs without validation improvement
)

In [None]:
# Training curves of the run above (per-epoch mean train / validation loss).
fig, ax = plt.subplots()
ax.plot(losses_train, label='Loss train')
ax.plot(losses_val, label='Loss val')
ax.set(xlabel='epoch', ylabel='loss', title='Autoencoder: train vs validation loss')
ax.legend();  # trailing ';' suppresses the Legend repr in the cell output

In [None]:
# Reconstruction loss on the held-out test patients.
test_loss = evaluate(model, dataloader_test, loss_fn, device, scaler)
print('Test loss:', test_loss)