# Importing modules 


In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import numpy as np
import random

from model import Model
from dataset import Dataset
import copy
import argparse
import os
 



### Command-line options (option.py)

In [None]:
parser = argparse.ArgumentParser(description='CMA_XD_VioDet')
# Feature lists and ground-truth file (paths relative to the working directory).
parser.add_argument('--rgb-list', default='list/rgb_new.list', help='list of rgb features ')
parser.add_argument('--flow-list', default='list/flow_new.list', help='list of flow features')
parser.add_argument('--audio-list', default='list/audio_new.list', help='list of audio features')
parser.add_argument('--test-rgb-list', default='list/rgb_test_new.list', help='list of test rgb features ')
parser.add_argument('--test-flow-list', default='list/flow_test_new.list', help='list of test flow features')
parser.add_argument('--test-audio-list', default='list/audio_test_new.list', help='list of test audio features')
parser.add_argument('--dataset-name', default='XD-Violence', help='dataset to train on XD-Violence')
parser.add_argument('--gt', default='list/gt_new.npy', help='file of ground truth ')

# Model and optimisation hyper-parameters.
parser.add_argument('--modality', default='MIX_ALL', help='the type of the input, AUDIO,RGB,FLOW, MIX1, MIX2, '
                                                          'or MIX3, MIX_ALL')
parser.add_argument('--lr', type=float, default=0.0005, help='learning rate (default: 0.0005)')
parser.add_argument('--batch-size', type=int, default=128, help='number of instances in a batch of data')
# type=int is required: without it a value given on the command line arrives
# as a string, and DataLoader fails when it validates num_workers.
parser.add_argument('--workers', type=int, default=8, help='number of workers in dataloader')
parser.add_argument('--model-name', default='new_model__MIXALL', help='name to save model')
parser.add_argument('--pretrained-ckpt', default=None, help='ckpt for pretrained model')
# 1024 + 128 = 1152; the help text previously (wrongly) advertised 2048.
parser.add_argument('--feature-size', type=int, default=1024+128, help='size of feature (default: 1152)')
parser.add_argument('--num-classes', type=int, default=1, help='number of class')
parser.add_argument('--max-seqlen', type=int, default=200, help='maximum sequence length during training')
parser.add_argument('--max-epoch', type=int, default=50, help='maximum iteration to train (default: 50)')
parser.add_argument('--seed', type=int, default=9, help='Random Initiation (default: 9)')


_StoreAction(option_strings=['--seed'], dest='seed', nargs=None, const=None, default=9, type=<class 'int'>, choices=None, required=False, help='Random Initiation (default: 9)', metavar=None)

### Initializing the datasets

In [None]:
def setup_seed(seed):
    """Seed every random number generator the script relies on.

    Applies ``seed`` to PyTorch (CPU, then all CUDA devices), NumPy's global
    generator and Python's ``random`` module, and asks cuDNN to select
    deterministic kernels so repeated runs give the same results.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
# Parse the default options. An explicit empty argv keeps this usable inside
# Jupyter, whose own sys.argv flags would otherwise be parsed.
args = parser.parse_args([])
print(args)
setup_seed(args.seed)

# Training split: shuffled, configured batch size. The Dataset is handed
# straight to the loader, which holds the only reference to it.
train_loader = DataLoader(
    Dataset(args, test_mode=False),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)
print(len(train_loader.sampler))

# Test split: fixed order, 5 items per batch.
# NOTE(review): presumably 5 views/crops of one video per batch — confirm
# against Dataset's test_mode layout.
test_loader = DataLoader(
    Dataset(args, test_mode=True),
    batch_size=5,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

### Building the model

In [None]:
# Re-read the default options and build the network. Module.cuda() moves the
# parameters in place and returns the same module, so rebinding is harmless.
args = parser.parse_args([])
model = Model(args).cuda() if torch.cuda.is_available() else Model(args)


In [None]:

# Ensure the checkpoint directory exists. exist_ok=True avoids the
# check-then-create race of os.path.exists followed by os.makedirs.
os.makedirs('./ckpt', exist_ok=True)

# BCELoss expects probabilities in [0, 1]; the model presumably ends in a
# sigmoid — TODO confirm against model.Model.
criterion = torch.nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# NOTE(review): T_max=60 while training runs for 50 epochs by default, so the
# cosine schedule never decays all the way to eta_min. Kept unchanged;
# confirm this is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60, eta_min=0)


In [None]:
## Load weights from a previously saved checkpoint into the current model.
pretrained_dict = torch.load('ckpt/xd_a2v__.pkl', map_location="cpu")
model_dict = model.state_dict()
# Checkpoints saved from a wrapped model (typically nn.DataParallel) prefix
# every key with "module." (7 characters); strip it so keys match this model.
# next(iter(...), "") keeps an empty checkpoint from raising IndexError.
if next(iter(pretrained_dict), "").startswith("module."):
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
else:
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite matching entries and load the merged state dict. This must run
# for BOTH branches: it used to sit inside `else`, so prefixed checkpoints
# were filtered but silently never loaded.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

### Training function

In [None]:
## model train
import torch
from loss import CLAS


def train(dataloader, model, optimizer, criterion):
    """Run one training epoch and return the sample-weighted mean loss.

    Args:
        dataloader: yields ``(inputs, label)`` batches. ``inputs`` is treated as
            ``(batch, time, feature)``: time steps whose features are all zero
            are counted as padding and trimmed from the batch.
        model: network being trained; batches are moved to the device of its
            parameters, so CPU-only runs work as well as CUDA ones.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss module forwarded to ``CLAS``.

    Returns:
        ``sum(loss * batch_size) / len(dataloader.sampler)`` over the epoch —
        the mean per-sample loss, assuming ``CLAS`` returns a per-batch mean
        (TODO confirm in loss.py).
    """
    device = next(model.parameters()).device
    t_loss = 0.0
    with torch.set_grad_enabled(True):
        model.train()
        for inputs, label in dataloader:
            # A time step is "real" if any of its features is non-zero;
            # seq_len counts those steps per sample.
            seq_len = torch.sum(torch.max(torch.abs(inputs), dim=2)[0] > 0, 1)
            # Drop the trailing padding shared by every sample in the batch.
            inputs = inputs[:, :torch.max(seq_len), :]
            inputs = inputs.float().to(device, non_blocking=True)
            label = label.float().to(device, non_blocking=True)
            logits = model(inputs)
            loss = CLAS(logits, label, seq_len, criterion)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by this batch's size (previously the number of batches
            # was used) so dividing by the sample count gives a true mean.
            t_loss += loss.item() * inputs.size(0)

    return t_loss / len(dataloader.sampler)



### Testing function

In [None]:
from sklearn.metrics import auc, precision_recall_curve
import numpy as np
import torch


def test(dataloader, model, gt, frames_per_snippet=16):
    """Evaluate ``model`` and return the frame-level PR-AUC.

    Args:
        dataloader: yields input tensors only (no labels). The logits of each
            batch are averaged over dim 0, so a batch is treated as several
            views of one clip (presumably 5 crops of a video — TODO confirm).
        model: network to evaluate; inputs are moved to the device of its
            parameters instead of always to CUDA.
        gt: frame-level binary ground truth, one entry per frame.
        frames_per_snippet: number of frames each predicted score covers. Each
            score is repeated this many times to align with ``gt``. Defaults
            to 16, the value that used to be hard-coded.

    Returns:
        Area under the precision-recall curve (scikit-learn ``auc`` over
        ``precision_recall_curve``).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    with torch.no_grad():
        model.eval()
        device = next(model.parameters()).device
        # Collect per-batch scores and concatenate once, instead of
        # re-allocating the growing tensor on every iteration.
        chunks = [torch.mean(model(inputs.to(device)), 0) for inputs in dataloader]
        if not chunks:
            raise ValueError("dataloader yielded no batches")
        pred = torch.cat(chunks).cpu().numpy()

        # np.repeat without an axis flattens, giving one score per frame.
        precision, recall, th = precision_recall_curve(
            np.asarray(gt), np.repeat(pred, frames_per_snippet))
        pr_auc = auc(recall, precision)
        return pr_auc




### Training and evaluating the model

In [None]:
# Train for args.max_epoch epochs (default 50). Evaluate after each epoch and
# keep a copy of the weights with the best test AP.
best_model_wts = copy.deepcopy(model.state_dict())
best_ap = 0.0
is_topk = True  # NOTE(review): not read anywhere in the visible code
gt = np.load(args.gt)  # default 'list/gt_new.npy', previously hard-coded
st = time.time()
for epoch in range(args.max_epoch):
    cls_loss = train(train_loader, model, optimizer, criterion)
    scheduler.step()
    ap = test(test_loader, model, gt)
    if ap > best_ap:
        best_ap = ap
        best_model_wts = copy.deepcopy(model.state_dict())
    print('[Epoch {}/{}]: cls loss: {} | epoch AP: {:.4f}'.format(
        epoch + 1, args.max_epoch, cls_loss, ap))
# The tracked metric is PR-AUC (AP), not accuracy.
print("The best AP is", best_ap)
print('Training time: {:.1f}s'.format(time.time() - st))

### Saving the model

In [None]:
# Restore the best-scoring weights from training, then write them to
# ./ckpt/<model-name>.pkl.
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), './ckpt/{}.pkl'.format(args.model_name))

### Loading pretrained model

In [None]:

# Reload the checkpoint weights into the current model before final testing.
pretrained_dict = torch.load('ckpt/xd_a2v__.pkl', map_location="cpu")
model_dict = model.state_dict()
# Checkpoints saved from a wrapped model (typically nn.DataParallel) prefix
# every key with "module." (7 characters); strip it so keys match this model.
# next(iter(...), "") keeps an empty checkpoint from raising IndexError.
if next(iter(pretrained_dict), "").startswith("module."):
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
else:
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite matching entries and load the merged state dict. This must run
# for BOTH branches: it used to sit inside `else`, so prefixed checkpoints
# were filtered but silently never loaded.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)


### Testing the model

In [None]:
# Evaluate the currently loaded weights on the test split. Report the PR-AUC
# and the wall-clock time the evaluation took (previously computed but never
# shown).
st = time.time()
pr_auc = test(test_loader, model, gt)
time_elapsed = time.time() - st
print('test AP: {:.4f}\n'.format(pr_auc))
print('test time: {:.2f}s'.format(time_elapsed))