In [None]:
import sys
sys.path.append('../scripts')

import yaml
import torch
from argparse import Namespace
from tqdm.notebook import tqdm

from runner import Runner

In [None]:
from metrics import LWLRAP


In [None]:

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Path to the training configuration YAML.
# NOTE(review): absolute, machine-specific path; consider making it relative to the
# repository (this notebook already assumes it sits next to ../scripts).
CONFIG_PATH = '/home/jupyter/rfcx_submission/config/training_config.yaml'

# Load the config as a plain dict. safe_load only builds standard YAML types
# (mappings, sequences, scalars), which is all a training config needs; FullLoader
# additionally accepts Python-specific tags such as !!python/tuple.
with open(CONFIG_PATH, 'r') as f:
    args = yaml.safe_load(f)

In [None]:
# Override the epoch count read from the YAML config.
args.update(epochs=5)

In [None]:
# Expose the config keys as attributes (e.g. args.epochs, args.training).
# The isinstance guard makes this cell safe to re-run: on a second run `args` is
# already a Namespace, and `Namespace(**args)` would raise TypeError because a
# Namespace is not a mapping.
if isinstance(args, dict):
    args = Namespace(**args)

In [None]:
# Build the Runner from the selected device and the config Namespace. Later cells call
# runner.set_data_loader() / runner.set_model() and then use runner.model,
# runner.train_dataloader and runner.eval_dataloader.
runner= Runner(device, args)

In [None]:
# Set up the data loaders. The training cell reads runner.train_dataloader and
# runner.eval_dataloader, presumably created here — TODO confirm in runner.py.
runner.set_data_loader()

In [None]:
# Build the model. Later cells use runner.model and runner.model.upstream, presumably
# attached here — TODO confirm in runner.py.
runner.set_model()

In [None]:
# Display the 'upstream' training flag. When it is falsy, the training loop puts
# runner.model.upstream into eval mode every epoch.
runner.args.training['upstream']

In [None]:
# Mean binary cross-entropy over all elements of the batch.
# NOTE(review): BCELoss expects model outputs already in [0, 1] (e.g. a final sigmoid);
# if the model emits raw logits, BCEWithLogitsLoss would be required — confirm.
loss_function = torch.nn.BCELoss(reduction="mean")

In [None]:
# Adam over every model parameter, starting at a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=runner.model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.4 after every 2 scheduler.step() calls
# (the training cell steps the scheduler once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=2, gamma=0.4)

In [None]:
# Train for args.epochs epochs, validating after each one.
# - Whenever the mean validation LWLRAP improves, the whole model object is saved to
#   best_model_<epoch>.pt.
# - After the last epoch, the final model is saved to last_model_<epoch>.pt.
#
# NOTE(review): batches are passed to the model exactly as the dataloaders yield them
# (no .to(device)); this assumes the Runner/dataset already place them on `device` — confirm.


def _mean(values):
    """Return the arithmetic mean of a list of numbers, or NaN if the list is empty.

    torch.mean() only accepts tensors, so it cannot be applied to the plain Python
    lists of per-batch losses/scores collected below.
    """
    return sum(values) / len(values) if values else float('nan')


best_lwrap = 0
last_epoch = None

for epoch in range(args.epochs):

    # ---- training pass ----
    train_loss = []
    train_lwlrap = []

    runner.model.train()
    # train() switches every submodule to training mode; unless the config enables
    # training the upstream, put it back into eval mode.
    if not runner.args.training['upstream']:
        runner.model.upstream.eval()

    progress = tqdm(runner.train_dataloader, total=len(runner.train_dataloader))
    for data, target in progress:
        optimizer.zero_grad()

        output = runner.model(data)
        loss = loss_function(output, target)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Detach so the stored metric does not keep this batch's autograd graph alive.
        # float() assumes LWLRAP returns a scalar (number or 0-d tensor) — TODO confirm.
        batch_score = float(LWLRAP(output.detach(), target))
        train_loss.append(batch_loss)
        train_lwlrap.append(batch_score)
        # Show running values on the progress bar instead of printing one line per batch.
        progress.set_postfix(loss=batch_loss, score=batch_score)

    # Learning rate used during this epoch (read before scheduler.step() below).
    lr = optimizer.param_groups[-1]['lr']
    print(f'Epoch : {epoch}  training end. LR: {lr}  Loss: {_mean(train_loss)}  lwrlrap_score: {_mean(train_lwlrap)}')

    # ---- validation pass ----
    val_loss = []
    val_lwlrap = []

    runner.model.eval()
    with torch.no_grad():
        for data, target in tqdm(runner.eval_dataloader, total=len(runner.eval_dataloader)):
            output = runner.model(data)
            val_loss.append(loss_function(output, target).item())
            val_lwlrap.append(float(LWLRAP(output, target)))

    mean_val_lwlrap = _mean(val_lwlrap)
    print(f'Valid Loss: {_mean(val_loss)}  lwrlrap_score: {mean_val_lwlrap}')

    # A NaN mean (empty eval loader) compares False, so it never counts as an improvement.
    if mean_val_lwlrap > best_lwrap:
        best_lwrap = mean_val_lwlrap
        torch.save(runner.model, f'best_model_{epoch}.pt')

    scheduler.step()
    last_epoch = epoch

# Save the final-epoch model under its own name so it cannot be mistaken for, or
# overwrite, a best-so-far checkpoint. Skipped when args.epochs == 0.
if last_epoch is not None:
    torch.save(runner.model, f'last_model_{last_epoch}.pt')