# Import libraries and helpers

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys

In [None]:
sys.path.append('./helpers_models/')

In [None]:
sys.path.append('./data_visualization_and_augmentations/')

In [None]:
sys.path.append('../torch_videovision/')

In [None]:
sys.path.append('./important_csvs/')

In [None]:
from helpers_resnet import *

In [None]:
# Notebook-wide matplotlib defaults: a (20, 10) figure size and a regular-weight,
# size-20 DejaVu Sans font.
font = dict(family='DejaVu Sans', weight='normal', size=20)
plt.rc('font', **font)
plt.rcParams['figure.figsize'] = (20, 10)

# Load Model, change head, freeze body

In [None]:
# Start from torchvision's ResNet-50 with pretrained weights.
resnet = torchvision.models.resnet50(pretrained=True)

# Replace the stock average-pooling layer and the final fully connected layer
# with AdaptiveConcatPool2d / Head (not defined in this notebook; presumably
# brought in by the helper star-import).
adaptive_pooling, head = AdaptiveConcatPool2d(), Head()
resnet.avgpool, resnet.fc = adaptive_pooling, head

# Expose only GPUs 0 and 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

In [None]:
# Wrap the network in nn.DataParallel; the wrapped model stays reachable as
# resnet.module.
resnet = nn.DataParallel(resnet)
# NOTE(review): check_freeze is not defined in this notebook; whether it freezes
# layers or only reports their trainable state cannot be told from here — confirm.
check_freeze(resnet.module)

#summary(resnet.module, torch.zeros(2,3,576,704).cuda())

# Transform pipelines. Training and validation use different spatial/temporal
# transforms; the tensor transform is shared by all three loaders.
# NOTE(review): the meaning of the positional arguments ('ImageNet', True, 2, 0)
# lives in the helper modules — confirm there.
tensor_transform = get_tensor_transform('ImageNet', True)
train_spat_transform = get_spatial_transform(2)
train_temp_transform = get_temporal_transform()
valid_spat_transform = get_spatial_transform(0)
valid_temp_transform = va.TemporalFit(size=16)

# Build the train / validation / test loaders from one CSV. The three boolean
# flags passed to get_df select a different split each time.
# class_image_paths and end_idx are overwritten for every split, so after this
# cell they hold the values of the test split only.
# NOTE(review): the 20, 1 and 270 arguments are opaque here; presumably
# sizes/lengths defined by get_df / get_loader — TODO confirm.
root_dir = '/media/scratch/astamoulakatos/nsea_video_jpegs/'
df = pd.read_csv('./small_dataset_csvs/events_with_number_of_frames_stratified.csv')
df_train = get_df(df, 20, True, False, False)
class_image_paths, end_idx = get_indices(df_train, root_dir)
train_loader = get_loader(1, 270, end_idx, class_image_paths, train_temp_transform, train_spat_transform, tensor_transform, False, True)
df_valid = get_df(df, 20, False, True, False)
class_image_paths, end_idx = get_indices(df_valid, root_dir)
valid_loader = get_loader(1, 270, end_idx, class_image_paths, valid_temp_transform, valid_spat_transform, tensor_transform, False, True)
# The test loader reuses the validation transforms.
df_test = get_df(df, 20, False, False, True)
class_image_paths, end_idx = get_indices(df_test, root_dir)
test_loader = get_loader(1, 270, end_idx, class_image_paths, valid_temp_transform, valid_spat_transform, tensor_transform, False, True)

torch.cuda.empty_cache()

# NOTE(review): lr, optimizer, criterion, scheduler and dataloaders are all
# reassigned in the "Training loop" section further down (lr=1e-1, optim.Adam,
# OneCycleLR with epochs=6) before training starts, so the AdamW / 10-epoch
# objects created here are not the ones handed to the training function.
lr = 1e-2
epochs = 10
optimizer = optim.AdamW(resnet.parameters(), lr=lr, weight_decay=1e-2)
criterion = nn.BCEWithLogitsLoss()
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs)

# Phase name -> loader mapping.
dataloaders = {
    "train": train_loader,
    "validation": valid_loader
}


In [None]:
# Pass the DataParallel-wrapped model and an all-zero CUDA tensor of shape
# (2, 3, 576, 704) to summary (helper not defined in this notebook; presumably
# it reports per-layer information — confirm).
summary(resnet, torch.zeros(2,3,576,704).cuda())

# Training loop

In [None]:
# Objects consumed by the training run: peak learning rate, target device,
# checkpoint directory, loss, optimiser, one-cycle schedule and the
# phase -> loader mapping.
lr = 1e-1
device = torch.device('cuda')
save_model_path = '/media/raid/astamoulakatos/saved-resnet-models/'

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(resnet.parameters(), lr=lr, weight_decay=1e-2)
# One scheduler step per training batch, over 6 epochs in total.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    steps_per_epoch=len(train_loader),
    epochs=6,
)
dataloaders = dict(train=train_loader, validation=valid_loader)

In [None]:
# Read the learning rate(s) most recently set by the scheduler. get_last_lr()
# is the supported accessor; get_lr() is intended to be called from inside
# scheduler.step() and emits a warning when used directly.
lrate = scheduler.get_last_lr()

In [None]:
# Display the learning rate value(s) captured from the scheduler in the previous cell.
lrate

In [None]:
# Length of the training loader; the same quantity is passed to OneCycleLR as
# steps_per_epoch.
len(dataloaders['train'])

In [None]:
def train_model_yo(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=6, save_dir=None):
    """Train and validate ``model``, checkpointing whenever validation loss improves.

    Every epoch runs a 'train' pass followed by a 'validation' pass. Predictions
    are obtained by thresholding ``sigmoid(outputs)`` at 0.5, and loss, accuracy
    and samples-averaged F1 are accumulated per sample.

    Args:
        dataloaders: mapping with 'train' and 'validation' keys; each value is
            iterated to yield ``(inputs, labels)`` batches.
        device: device the model and every batch are moved to.
        model: network to optimise.
        criterion: loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: optimiser, stepped once per training batch.
        scheduler: LR scheduler, stepped once per training batch right after
            the optimiser.
        num_epochs: number of train + validation epochs to run.
        save_dir: directory for checkpoints. When ``None`` the module-level
            ``save_model_path`` is used, as before.

    Returns:
        None.

    Side effects:
        * Prints running metrics every 100 batches.
        * Writes ``best-checkpoint-XXXepoch.pth`` to the checkpoint directory
          whenever the validation loss improves, then deletes all but the last
          three matching files (in sorted filename order).
        * After every epoch, pickles the six metric histories to
          ``resnet_{train,val}_{losses,acc,f1}.txt`` in the working directory.
    """
    if save_dir is None:
        save_dir = save_model_path
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    model = model.to(device)
    best_val_loss = float('inf')

    val_losses, val_acc, val_f1 = [], [], []
    train_losses, train_acc, train_f1 = [], [], []
    # File name -> history list; re-dumped in this order after every epoch.
    history_files = {
        "resnet_val_losses.txt": val_losses,
        "resnet_val_acc.txt": val_acc,
        "resnet_val_f1.txt": val_f1,
        "resnet_train_losses.txt": train_losses,
        "resnet_train_acc.txt": train_acc,
        "resnet_train_f1.txt": train_f1,
    }

    for epoch in range(num_epochs):
        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_acc = 0.0
            running_f1 = 0.0
            # Samples processed so far in this phase; the denominator for both
            # the running printout and the epoch means.
            seen = 0

            for counter, (inputs, labels) in enumerate(Bar(dataloaders[phase])):
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_size = inputs.size(0)

                # No autograd graph is needed during validation.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                preds = (torch.sigmoid(outputs.detach()) > 0.5).to(torch.float32)
                y_true = labels.detach().cpu().numpy()
                y_pred = preds.cpu().numpy()

                running_loss += loss.item() * batch_size
                running_acc += accuracy_score(y_true, y_pred) * batch_size
                running_f1 += f1_score(y_true, y_pred, average="samples") * batch_size
                seen += batch_size

                if counter != 0 and counter % 100 == 0:
                    label = 'Training' if is_train else 'Validation'
                    print('  {} Loss: {:.4f} Acc: {:.4f} F1: {:.4f}'.format(
                        label, running_loss / seen, running_acc / seen, running_f1 / seen))

            # Normalise by the samples actually processed so the means stay
            # correct even if the loader does not yield len(dataset) samples.
            denom = max(seen, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_acc / denom
            epoch_f1 = running_f1 / denom

            if is_train:
                train_losses.append(epoch_loss)
                train_acc.append(epoch_acc)
                train_f1.append(epoch_f1)
            else:
                val_losses.append(epoch_loss)
                val_acc.append(epoch_acc)
                val_f1.append(epoch_f1)

                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    save_path = os.path.join(save_dir, f'best-checkpoint-{str(epoch).zfill(3)}epoch.pth')
                    states = {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'val_loss': epoch_loss,
                        'epoch': epoch,
                    }
                    torch.save(states, save_path)
                    # Keep only the last three checkpoints in sorted filename
                    # order (the epoch number is zero-padded).
                    for stale in sorted(glob(os.path.join(save_dir, 'best-checkpoint-*epoch.pth')))[:-3]:
                        os.remove(stale)

        # Persist the histories after every epoch so an interrupted run still
        # leaves usable curves on disk.
        for file_name, values in history_files.items():
            with open(file_name, "wb") as fp:
                pickle.dump(values, fp)

In [None]:
# Launch training. num_epochs is kept equal to the epochs=6 used when building
# the OneCycleLR scheduler, because the training loop steps the scheduler once
# per training batch; keep the two values in sync.
train_model_yo(dataloaders, device, resnet, criterion, optimizer, scheduler, num_epochs=6)