# Import libraries and helpers

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys

In [None]:
sys.path.append('./helpers_models/')

In [None]:
sys.path.append('./data_visualization_and_augmentations/')

In [None]:
sys.path.append('../torch_videovision/')

In [None]:
sys.path.append('../video-classification/ResNetCRNN/')

In [None]:
sys.path.append('./important_csvs/')

In [None]:
from helpers_lstm import *

In [None]:
# Notebook-wide matplotlib defaults: large figures and a large, regular-weight font.
plt.rcParams['figure.figsize'] = (20, 10)
font = dict(family='DejaVu Sans', weight='normal', size=20)
plt.rc('font', **font)

# Load training and validation sets

In [None]:
# Tensor-level transform built by a project helper; it is handed to get_loader
# for both splits. NOTE(review): 'ImageNet' presumably selects ImageNet
# normalization statistics — confirm in the helper.
tensor_transform = get_tensor_transform('ImageNet')

In [None]:
# Video-level transforms, one per split. NOTE(review): the integer argument is
# opaque from here; presumably it selects the augmentation level (2 for
# training, 0 = none for validation) — confirm in get_video_transform.
train_transform = get_video_transform(2)
valid_transform = get_video_transform(0)

In [None]:
# Read the event table and run it through get_df for the training split.
# NOTE(review): 16 presumably is the clip length in frames and the boolean
# selects the split (False here, True for the validation set) — confirm.
df = pd.read_csv('./important_csvs/events_with_number_of_frames_stratified.csv')
df = get_df(df, 16, False)

In [None]:
# Derive the two structures get_loader consumes (image paths and end indices)
# from the training frame.
class_image_paths, end_idx = get_indices(df)

In [None]:
# Training loader. NOTE(review): the positional arguments are opaque from here;
# 16 and 4 presumably are sequence length and batch size, and the trailing
# booleans are loader flags — confirm against get_loader's signature.
train_loader = get_loader(16, 4, end_idx, class_image_paths, train_transform, tensor_transform, True, False)

In [None]:
# show_batch(train_loader,4)

In [None]:
# Re-read the same CSV and rebind `df`, this time calling get_df with True as
# the last argument. NOTE(review): presumably this selects the validation
# split — confirm in get_df.
df = pd.read_csv('./important_csvs/events_with_number_of_frames_stratified.csv')
df = get_df(df, 16, True)

In [None]:
# Rebinds class_image_paths and end_idx, now computed from the validation frame;
# the training values are no longer reachable under these names.
class_image_paths, end_idx = get_indices(df)

In [None]:
# Validation loader, built with the validation transform.
# NOTE(review): the trailing flags (True, False) are identical to the training
# loader's; if one of them enables shuffling or random sampling, that is
# probably unintended for validation — confirm against get_loader.
valid_loader = get_loader(16, 4, end_idx, class_image_paths, valid_transform, tensor_transform, True, False)

# Load Model, change head, freeze body

In [None]:
# Device that the encoder, the decoder and every batch in train_model are moved
# to. CUDA is hard-coded; other cells also call .cuda() directly.
device = torch.device('cuda')

In [None]:
# Build the CNN encoder on the target device, then replace element 8 of its
# `resnet` container with an AdaptiveConcatPool2d layer (the "change head" step).
cnn_encoder = ResCNNEncoder().to(device)
adaptive_pool = AdaptiveConcatPool2d()
cnn_encoder.resnet[8] = adaptive_pool

In [None]:
# Freeze every encoder parameter first ...
for weight in cnn_encoder.parameters():
    weight.requires_grad = False

# ... then re-enable gradients only for the replaced pooling layer and the
# head layers (headbn1, fc1), so just those parts of the encoder are trained.
for trainable_part in (cnn_encoder.resnet[8], cnn_encoder.headbn1, cnn_encoder.fc1):
    for weight in trainable_part.parameters():
        weight.requires_grad = True

In [None]:
# Attention decoder on the target device. batch_size=4 matches the 4 passed to
# get_loader (presumably the loader batch size — TODO confirm).
rnn_decoder = DecoderRNNattention(batch_size=4).to(device)
# Explicitly mark every decoder parameter as trainable (unlike the mostly
# frozen encoder).
for param in rnn_decoder.parameters():
    param.requires_grad = True

In [None]:
# Rebinds the encoder/decoder to the objects returned by parallelize_model
# (their underlying modules are reached via `.module` in later cells) and
# collects `crnn_params`, the parameter collection handed to the optimizers.
crnn_params, cnn_encoder, rnn_decoder = parallelize_model(cnn_encoder, rnn_decoder)

## Torchsummary

In [None]:
# Chain the unwrapped (`.module`) encoder and decoder for the summaries only;
# `model` is rebound to the wrapped pair in the LR Finder section.
model = nn.Sequential(cnn_encoder.module,rnn_decoder.module)

In [None]:
# Summarize the full encoder+decoder with an all-zero dummy input of shape
# (2, 16, 3, 576, 704) — presumably (batch, frames, channels, height, width);
# TODO confirm.
summary(model, torch.zeros(2,16,3,576,704).cuda())

In [None]:
# Summarize the encoder alone, using the same dummy input shape as the
# full-model summary.
summary(cnn_encoder.module, torch.zeros(2,16,3,576,704).cuda())

In [None]:
rnn_decoder.module

In [None]:
# Summarize the decoder alone with a (2, 16, 512) dummy input.
# Original author's note: "change the division by 2 to work" — a division by 2
# apparently has to be changed for this call to run.
# TODO(review): record which division is meant.
summary(rnn_decoder.module, torch.zeros(2,16,512).cuda())

## LR Finder

In [None]:
# IMPORTANT: rebind `model` to the wrapped encoder/decoder returned by
# parallelize_model (not the bare `.module`s used for the summaries). This is
# the object passed to LRFinder and to train_model.
model = nn.Sequential(cnn_encoder,rnn_decoder)

In [None]:
model

In [None]:
# Release cached GPU memory before running the LR range test.
torch.cuda.empty_cache()

In [None]:
# Learning-rate range test. The loss takes raw logits (BCEWithLogitsLoss) and
# the throwaway optimizer starts at lr=1e-7; the sweep runs for 200 iterations
# over train_loader with an upper bound of end_lr=100.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(crnn_params, lr=1e-7, weight_decay=1e-2)
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=100, num_iter=200)
lr_finder.plot() # inspect the loss vs. learning-rate curve to pick `lr`
lr_finder.reset() # restore the model and optimizer to their initial state

# Training loop

In [None]:
# Learning rate used both as the Adam lr and as OneCycleLR's max_lr.
lr = 1e-1
lr

In [None]:
# Fresh Adam optimizer for the actual training run; replaces the one created
# for the LR range test.
optimizer = optim.Adam(crnn_params, lr=lr, weight_decay=1e-2)

In [None]:
# Loss on raw logits; train_model applies the sigmoid itself when it thresholds
# predictions.
criterion = nn.BCEWithLogitsLoss()

In [None]:
# One-cycle schedule sized for 6 epochs of len(train_loader) steps each.
# train_model steps it once per training batch, so `epochs` here should match
# the num_epochs passed to train_model.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=6)

In [None]:
# Release cached GPU memory before training starts.
torch.cuda.empty_cache()

In [None]:
# Phase name -> loader; train_model reads this module-level mapping by the
# keys "train" and "validation".
dataloaders = dict(train=train_loader, validation=valid_loader)

In [None]:
# Directory where train_model writes (and prunes) its best-checkpoint-*.pth
# files. Machine-specific absolute path.
save_model_path = '/media/raid/astamoulakatos/saved-lstm-models/'

In [None]:
def _dump_history(history):
    """Pickle every metric list in ``history`` to its ``cnnlstm_*.txt`` file."""
    # Despite the .txt extension these files hold pickled Python lists.
    file_prefixes = {'train': 'train', 'validation': 'val'}
    for phase, metrics in history.items():
        for metric_name, values in metrics.items():
            file_name = f"cnnlstm_{file_prefixes[phase]}_{metric_name}.txt"
            with open(file_name, "wb") as fp:
                pickle.dump(values, fp)


def train_model(model, criterion, optimizer, scheduler, num_epochs=6):
    """Train and validate ``model``, checkpointing on validation-loss improvement.

    Each epoch runs a 'train' phase followed by a 'validation' phase over the
    module-level ``dataloaders`` mapping. Loss, accuracy and samples-averaged F1
    are accumulated per batch (weighted by batch size) and averaged over the
    dataset length. After every epoch the metric histories are pickled to
    ``cnnlstm_{train,val}_{losses,acc,f1}.txt`` in the working directory.

    Whenever the validation loss improves, a checkpoint with the model,
    optimizer and scheduler state is written under ``save_model_path`` and all
    but the three most recent ``best-checkpoint-*epoch.pth`` files are deleted.

    Args:
        model: Network mapping a batch of inputs to logits.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per training batch (right after
            the optimizer), so it must be configured for at least
            ``num_epochs * len(dataloaders['train'])`` steps.
        num_epochs: Number of train/validation epochs to run.

    Returns:
        None. Results are persisted to disk as described above.
    """
    model = model.to(device)
    # Start from +inf so the first validation epoch always produces a
    # checkpoint, regardless of the loss scale.
    best_val_loss = float('inf')

    history = {
        'train': {'losses': [], 'acc': [], 'f1': []},
        'validation': {'losses': [], 'acc': [], 'f1': []},
    }

    for epoch in range(num_epochs):
        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_acc = 0.0
            running_f1 = 0.0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_size = inputs.size(0)

                # Build the autograd graph only while training; validation
                # does not need it and would just waste GPU memory.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                # Threshold the sigmoid of the logits at 0.5 to get 0/1
                # predictions for the metric functions.
                preds = (torch.sigmoid(outputs.detach()) > 0.5).to(torch.float32)
                y_true = labels.detach().cpu().numpy()
                y_pred = preds.cpu().numpy()

                running_loss += loss.item() * batch_size
                running_acc += accuracy_score(y_true, y_pred) * batch_size
                running_f1 += f1_score(y_true, y_pred, average="samples") * batch_size

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_acc / dataset_size
            epoch_f1 = running_f1 / dataset_size

            history[phase]['losses'].append(epoch_loss)
            history[phase]['acc'].append(epoch_acc)
            history[phase]['f1'].append(epoch_f1)

            if phase == 'validation' and epoch_loss < best_val_loss:
                best_val_loss = epoch_loss
                save_path = f'{save_model_path}/best-checkpoint-{str(epoch).zfill(3)}epoch.pth'
                states = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': epoch_loss,
                    'epoch': epoch,
                }

                # The original passed an undefined name (`save_file_path`)
                # here, which raised NameError on the first improvement.
                torch.save(states, save_path)
                # Zero-padded epoch numbers make lexical order chronological,
                # so this keeps only the three newest best checkpoints.
                for path in sorted(glob(f'{save_model_path}/best-checkpoint-*epoch.pth'))[:-3]:
                    os.remove(path)

        # Persist after every epoch so partial results survive an interruption.
        _dump_history(history)

In [None]:
# Run training. num_epochs should match the `epochs=6` given to OneCycleLR,
# because train_model steps the scheduler once per training batch.
train_model(model, criterion, optimizer, scheduler, num_epochs=6)

# Load saved model for more training