In [1]:
import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [2]:
train_path = 'CNN_10ktokens/train'
test_path = 'CNN_10ktokens/test'
valid_path = 'CNN_10ktokens/validation'

In [3]:
img_height = 288
img_width = 432 

In [4]:
train_dataset = ImageFolder(train_path,transform = transforms.Compose([
    transforms.Resize((150,150)),transforms.ToTensor()
]))
val_dataset = ImageFolder(valid_path,transform = transforms.Compose([
    transforms.Resize((150,150)),transforms.ToTensor()
]))
test_dataset = ImageFolder(test_path,transform = transforms.Compose([
    transforms.Resize((150,150)),transforms.ToTensor()
]))

In [5]:
from torch.utils.data.dataloader import DataLoader
batch_size=16
train_dl = DataLoader(train_dataset, batch_size, shuffle = True, num_workers = 4, pin_memory = True)
val_dl = DataLoader(val_dataset, batch_size, num_workers = 4, pin_memory = True)
test_dl = DataLoader(test_dataset, batch_size, num_workers = 4, pin_memory = True)

In [83]:
import torch.nn as nn
import torch.nn.functional as F

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

def confusion(prediction, truth):
    """ Returns the confusion matrix for the values in the `prediction` and `truth`
    tensors, i.e. the amount of positions where the values of `prediction`
    and `truth` are
    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)
    """

    confusion_vector = prediction / truth
    # Element-wise division of the 2 tensors returns a new tensor which holds a
    # unique value for each case:
    #   1     where prediction and truth are 1 (True Positive)
    #   inf   where prediction is 1 and truth is 0 (False Positive)
    #   nan   where prediction and truth are 0 (True Negative)
    #   0     where prediction is 0 and truth is 1 (False Negative)

    true_positives = torch.sum(confusion_vector == 1)
    false_positives = torch.sum(confusion_vector == float('inf'))
    true_negatives = torch.sum(torch.isnan(confusion_vector))
    false_negatives = torch.sum(confusion_vector == 0)

    return true_positives, false_positives, true_negatives, false_negatives

class ImageClassificationBase(nn.Module):
    
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        tp,fp,tn,fn = confusion(torch.argmax(out,1),labels)
        return {'val_loss': loss.detach(), 'val_acc': acc,'True Positive':tp,'False Positive':fp,'True Negative':tn,'False Negative':fn}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        batch_tp = [x['True Positive'] for x in outputs]
        epoch_tp = torch.stack(batch_tp).sum()
        batch_fp = [x['False Positive'] for x in outputs]
        epoch_fp = torch.stack(batch_fp).sum()
        batch_tn = [x['True Negative'] for x in outputs]
        epoch_tn = torch.stack(batch_tn).sum()
        batch_fn = [x['False Negative'] for x in outputs]
        epoch_fn = torch.stack(batch_fn).sum()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item(),'True Positive':epoch_tp.item(),'False Positive':epoch_fp.item(),'True Negative':epoch_tn.item(),'False Negative':epoch_fn.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))

In [84]:
class malwareClassification(ImageClassificationBase):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            
            nn.Conv2d(3, 32, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(32,64, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten(),
            nn.Linear(360000,2)
        )
    
    def forward(self, xb):
        return self.network(xb)

In [None]:
def fit(epochs, lr, model, train_loader, val_loader, opt_func = torch.optim.SGD):
    
    history = []
    optimizer = opt_func(model.parameters(),lr)
    for epoch in range(epochs):
        
        model.train()
        train_losses = []
        for batch in train_loader:
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)
        torch.save(model.state_dict(), 'Saved_models/'+'malware_class_10k_tok_'+ str(result['val_acc'])+'.pth')
    
    return history

In [None]:
num_epochs = 50
opt_func = torch.optim.Adam
lr = 0.001
model = malwareClassification()

In [None]:
history = fit(30, lr, model, train_dl, val_dl, opt_func)

In [85]:
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

In [86]:
from sklearn.metrics import confusion_matrix

@torch.no_grad()
def eval_model(model,dataset):
    model.eval()
    y_labels = []
    y_preds = []
    
    for images,labels in dataset:
        out = model(images)
        preds = torch.argmax(out,1).data.cpu().numpy()
        y_preds.extend(preds)
        labels = labels.data.cpu().numpy()
        y_labels.extend(labels)
        
    cf_matrix = confusion_matrix(y_labels, y_preds)
    return cf_matrix

def calculate_metrics(model_path,dataset):
    model = malwareClassification()
    model.load_state_dict(torch.load(model_path))
    test_result = evaluate(model,dataset)
#     print(test_result['val_acc'])
    return test_result

In [87]:
# calculate_metrics('malware_class_23feb_8109.pth',test_dl)

{'val_loss': 4.0755462646484375,
 'val_acc': 0.8125,
 'True Positive': 79,
 'False Positive': 24,
 'True Negative': 107,
 'False Negative': 21}