In [1]:
!pip install -q torch_geometric rdkit torchmetrics

[0m

In [2]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import random
import itertools

# RDkit
from rdkit import Chem, RDLogger
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit.Chem import PandasTools

# Pytorch and Pytorch Geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader

from torch_geometric.nn import GAE, GCNConv, global_mean_pool

import torchmetrics
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split



In [3]:
lr = 2e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 256
epochs = 20

In [4]:
desc = pd.read_csv('/kaggle/input/ddi-france/nodes.tsv', sep='\t')
ddi = pd.read_csv('/kaggle/input/ddi-france/interactions.tsv', sep='\t')
desc.loc[desc['DATABASE_ID']=='DB00526', 'SMILES'] = 'O1C(=O)C(=O)O[Pt-2]12[NH2+]C0CCCCC0[NH2+]2'
desc.loc[desc['DATABASE_ID']=='DB13145', 'SMILES'] = 'O=C1O[Pt-2]([NH3+])([NH3+])OC1'
desc.loc[desc['DATABASE_ID']=='DB00515', 'SMILES'] = '[NH3+]-[Pt-2](Cl)(Cl)[NH3+]'
desc.loc[desc['DATABASE_ID']=='DB00958', 'SMILES'] = 'C1CC2(C1)C(=O)O[Pt-2]([NH3+])([NH3+])OC2=O'
desc.loc[desc['DATABASE_ID']=='DB01999', 'SMILES'] = 'C1=CC(=CC=C1C2=C3C=CC(=C(C4=CC=C([NH]4)C(=C5C=CC(=N5)C(=C6C=CC2=N6)C7=CC=C(C=C7)S(=O)(=O)O)C8=CC=C(C=C8)S(=O)(=O)O)C9=CC=C(C=C9)S(=O)(=O)O)[NH]3)S(=O)(=O)O'
desc.loc[desc['DATABASE_ID']=='DB11630', 'SMILES'] = 'C1CC2=NC1=C(C3=CC=C(N3)C(=C4C=CC(=N4)C(=C5C=CC(=C2C6=CC(=CC=C6)O)N5)C7=CC(=CC=C7)O)C8=CC(=CC=C8)O)C9=CC(=CC=C9)O'
desc

Unnamed: 0,DATABASE_ID,SMILES,MOLECULAR_WEIGHT,JCHEM_AVERAGE_POLARIZABILITY,JCHEM_BIOAVAILABILITY,JCHEM_DONOR_COUNT,JCHEM_FORMAL_CHARGE,JCHEM_NEUTRAL_CHARGE,JCHEM_NUMBER_OF_RINGS,JCHEM_PHYSIOLOGICAL_CHARGE,JCHEM_PKA,JCHEM_PKA_STRONGEST_ACIDIC,JCHEM_PKA_STRONGEST_BASIC,JCHEM_POLAR_SURFACE_AREA,JCHEM_REFRACTIVITY,JCHEM_ROTATABLE_BOND_COUNT
0,DB00006,CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...,2180.2853,218.543031,0.0,28.0,0.0,-4.0,6.0,-4.0,3.167211,2.784541,11.878407,901.57,543.3342,66.0
1,DB00007,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=...,1209.3983,125.237464,0.0,16.0,0.0,1.0,6.0,1.0,10.979074,9.489203,11.918319,429.04,327.2417,32.0
2,DB00014,CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...,1269.4105,130.735742,0.0,17.0,0.0,2.0,6.0,1.0,10.004730,9.357765,10.911999,495.89,325.8388,33.0
3,DB00027,CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...,1811.2530,194.731121,0.0,20.0,0.0,,8.0,0.0,11.945820,11.560888,,519.89,492.3329,50.0
4,DB00035,NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...,1069.2200,104.780422,0.0,14.0,0.0,,4.0,1.0,11.344259,9.496981,11.771940,435.41,279.7799,19.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11570,DB17379,CC(C)C1=C(O)C(O)=C(C=O)C2=C(O)C(=C(C)C=C12)C1=...,518.5544,55.980570,0.0,6.0,0.0,0.0,4.0,-1.0,8.319379,7.799882,-6.149501,155.52,147.6120,4.0
11571,DB17383,CN1CCN(CC2=CC=C(NC(=O)C3=NNC=C3NC3=C4C=CNC4=NC...,431.5040,,,,,,,,,,,,,
11572,DB17384,CC1=C2N=C(C3=CC=CC=C3Cl)C3=C(NC2=NN1)C=C(N=C3)...,394.8600,,,,,,,,,,,,,
11573,DB17385,CC[C@@]1(OC(=O)C(C)ON=C2C3=C(C4=C2C=C(C=C4[N+]...,850.7100,,,,,,,,,,,,,


In [5]:
ddi

Unnamed: 0,source,dst
0,DB00006,DB06605
1,DB00006,DB06695
2,DB00006,DB01254
3,DB00006,DB01609
4,DB00006,DB01586
...,...,...
2369806,DB17083,DB13954
2369807,DB17083,DB13955
2369808,DB17083,DB13956
2369809,DB17083,DB14055


In [6]:
desc = desc[desc.columns[~(desc.isnull().sum()>12)]].fillna(0)
features = MinMaxScaler().fit_transform(desc[desc.columns[2:]])
desc[desc.columns[2:]] = features
desc

Unnamed: 0,DATABASE_ID,SMILES,MOLECULAR_WEIGHT,JCHEM_AVERAGE_POLARIZABILITY,JCHEM_BIOAVAILABILITY,JCHEM_DONOR_COUNT,JCHEM_FORMAL_CHARGE,JCHEM_NUMBER_OF_RINGS,JCHEM_PHYSIOLOGICAL_CHARGE,JCHEM_POLAR_SURFACE_AREA,JCHEM_REFRACTIVITY,JCHEM_ROTATABLE_BOND_COUNT
0,DB00006,CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...,0.356233,0.493477,0.0,0.424242,0.7,0.171429,0.333333,0.474458,0.463277,0.442953
1,DB00007,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=...,0.197455,0.282790,0.0,0.242424,0.7,0.171429,0.541667,0.225786,0.279025,0.214765
2,DB00014,CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...,0.207269,0.295205,0.0,0.257576,0.7,0.171429,0.541667,0.260966,0.277828,0.221477
3,DB00027,CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...,0.295882,0.439709,0.0,0.303030,0.7,0.228571,0.500000,0.273596,0.419791,0.335570
4,DB00035,NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...,0.174530,0.236597,0.0,0.212121,0.7,0.114286,0.541667,0.229138,0.238556,0.127517
...,...,...,...,...,...,...,...,...,...,...,...,...
11570,DB17379,CC(C)C1=C(O)C(O)=C(C=O)C2=C(O)C(=C(C)C=C12)C1=...,0.084475,0.126406,0.0,0.090909,0.7,0.114286,0.458333,0.081844,0.125862,0.026846
11571,DB17383,CN1CCN(CC2=CC=C(NC(=O)C3=NNC=C3NC3=C4C=CNC4=NC...,0.070238,0.000000,0.0,0.000000,0.7,0.000000,0.500000,0.000000,0.000000,0.000000
11572,DB17384,CC1=C2N=C(C3=CC=CC=C3Cl)C3=C(NC2=NN1)C=C(N=C3)...,0.064246,0.000000,0.0,0.000000,0.7,0.000000,0.500000,0.000000,0.000000,0.000000
11573,DB17385,CC[C@@]1(OC(=O)C(C)ON=C2C3=C(C4=C2C=C(C=C4[N+]...,0.138795,0.000000,0.0,0.000000,0.7,0.000000,0.500000,0.000000,0.000000,0.000000


In [7]:
x_map = {
    'atomic_num': list(range(0, 119)),
    'chirality': ['CHI_UNSPECIFIED','CHI_TETRAHEDRAL_CW','CHI_TETRAHEDRAL_CCW','CHI_OTHER','CHI_TETRAHEDRAL','CHI_ALLENE','CHI_SQUAREPLANAR','CHI_TRIGONALBIPYRAMIDAL','CHI_OCTAHEDRAL',],
    'degree':list(range(0, 11)),
    'formal_charge':list(range(-5, 7)),
    'num_hs':list(range(0, 9)),
    'num_radical_electrons':list(range(0, 5)),
    'hybridization': ['UNSPECIFIED','S','SP','SP2','SP3','SP3D','SP3D2','OTHER',],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

e_map = {
    'bond_type': ['UNSPECIFIED','SINGLE','DOUBLE','TRIPLE','QUADRUPLE','QUINTUPLE','HEXTUPLE','ONEANDAHALF','TWOANDAHALF','THREEANDAHALF','FOURANDAHALF','FIVEANDAHALF','AROMATIC','IONIC','HYDROGEN','THREECENTER','DATIVEONE','DATIVE','DATIVEL','DATIVER','OTHER','ZERO',],
    'stereo': ['STEREONONE','STEREOANY','STEREOZ','STEREOE','STEREOCIS','STEREOTRANS',],
    'is_conjugated': [False, True],
}


def from_smiles(smiles: str = None, mol: Chem.rdchem.Mol = None, with_hydrogen: bool = True, kekulize: bool = False) -> 'Data':
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance.

    Args:
        smiles (str): The SMILES string.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)
    """
    
    RDLogger.DisableLog('rdApp.*')

    if smiles and mol is None:
        mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        Chem.Kekulize(mol)

    xs = []
    for atom in mol.GetAtoms():
        x = []
        x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
        x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
        x.append(x_map['degree'].index(atom.GetTotalDegree()))
        x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
        x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
        x.append(x_map['num_radical_electrons'].index(atom.GetNumRadicalElectrons()))
        x.append(x_map['hybridization'].index(str(atom.GetHybridization())))
        x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
        x.append(x_map['is_in_ring'].index(atom.IsInRing()))
        xs.append(x)

    x = torch.tensor(xs, dtype=torch.float32).view(-1, 9)

    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        e = []
        e.append(e_map['bond_type'].index(str(bond.GetBondType())))
        e.append(e_map['stereo'].index(str(bond.GetStereo())))
        e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

        edge_indices += [[i, j], [j, i]]
        edge_attrs += [e, e]

    edge_index = torch.tensor(edge_indices)
    edge_index = edge_index.t().to(torch.long).view(2, -1)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

    if edge_index.numel() > 0:  # Sort indices.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

In [8]:
mol_list = {}
for i, row in tqdm(desc.iterrows(), total=len(desc)):
    try:
        mol_list[row['DATABASE_ID']] = [from_smiles(row['SMILES']), torch.tensor(np.float32(row.values[2:]))]
        if 0 in from_smiles(row['SMILES']).x.shape:
            print(row.values)
    except:
        pass

  0%|          | 0/11575 [00:00<?, ?it/s]

In [9]:
class DDIDataset(Dataset):
    def __init__(self, ddi_df, mol_list):
        super().__init__()
        self.ddi_df = ddi_df
        self.mol_list = mol_list
    
    def __len__(self):
        return len(self.ddi_df)
    
    def __getitem__(self, idx):
        x = self.ddi_df.iloc[idx]['source']
        y = self.ddi_df.iloc[idx]['dst']
        target = self.ddi_df.iloc[idx]['DDI']
        
        return *self.mol_list[x], *self.mol_list[y], float(target)

In [10]:
ddi1, ddi2 = train_test_split(list(set(ddi['source'].unique().tolist() + ddi['dst'].unique().tolist())), test_size=0.15, random_state=42)

In [11]:
train_ddi = ddi.loc[ddi['dst'].isin(ddi1) & ddi['source'].isin(ddi1)]
test_ddi = ddi.drop(train_ddi.index)

In [12]:
ddi_df = train_ddi
total_set = sorted(list(set(ddi_df['source'].values.tolist() + ddi_df['dst'].values.tolist()) ))
neg = random.choices(list(set(itertools.combinations(total_set, 2)) - set(ddi_df.itertuples(index=False, name=None))), k = 2*len(ddi_df))
ddi_df['DDI'] = [1]*len(ddi_df)
neg_df = pd.DataFrame(neg, columns=['source', 'dst'])
neg_df['DDI'] = [0]*len(neg_df)
train_ddi = pd.concat([ddi_df, pd.DataFrame({'source': ddi_df['dst'], 'dst': ddi_df['source'], 'DDI': ddi_df['DDI']}), neg_df])
train_ddi = train_ddi.sample(frac=1, ignore_index=True)

ddi_df = test_ddi
total_set = sorted(list(set(ddi_df['source'].values.tolist() + ddi_df['dst'].values.tolist()) ))
neg = random.choices(list(set(itertools.product(total_set, total_set)) - set(ddi_df.itertuples(index=False, name=None))), k = 2*len(ddi_df))
ddi_df['DDI'] = [1]*len(ddi_df)
neg_df = pd.DataFrame(neg, columns=['source', 'dst'])
neg_df['DDI'] = [0]*len(neg_df)
test_ddi = pd.concat([ddi_df, pd.DataFrame({'source': ddi_df['dst'], 'dst': ddi_df['source'], 'DDI': ddi_df['DDI']}), neg_df])
test_ddi = test_ddi.sample(frac=1, ignore_index=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ddi_df['DDI'] = [1]*len(ddi_df)


In [13]:
train_dataset = DDIDataset(train_ddi, mol_list)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataset = DDIDataset(test_ddi, mol_list)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [14]:
class GCN(nn.Module):
    def __init__(self, input_features):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_features, 8)
        self.conv2 = GCNConv(8, 8)
#         self.conv3 = GCNConv(8, 4)
        self.linear = nn.Linear(8, 8)
        self.bn = nn.BatchNorm1d(8)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
#         x = x.relu()
#         x = self.conv3(x, edge_index)
        x = self.linear(x)
        x = F.relu(x)
        x = self.bn(x)
        return x

class Model(nn.Module):
    def __init__(self, encoder):
        super(Model, self).__init__()
        self.encoder = encoder
        self.linear1 = nn.Linear(128+10, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.linear2 = nn.Linear(256, 64)
        self.bn3 = nn.BatchNorm1d(64)
        self.linear3 = nn.Linear(64, 1)

    def forward(self, x, x_feat, y, y_feat):
        batch = x.batch
        x = self.encoder(x.x, x.edge_index, batch)
        x = global_mean_pool(x, batch)
        x = F.adaptive_avg_pool1d(x, (128))
        x = torch.concat([x, x_feat], dim=1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.bn2(x)

        batch = y.batch        
        y = self.encoder(y.x, y.edge_index, batch)
        y = global_mean_pool(y, batch)
        y = F.adaptive_avg_pool1d(y, (128))
        y = torch.concat([y, y_feat], dim=1)
        y = self.linear1(y)
        y = F.relu(y)
        y = self.bn2(y)

        z = torch.concat([x, y], dim=1)
        z = self.linear2(z)
        z = F.relu(z)
        z = self.bn3(z)
        z = self.linear3(z)
        return z

In [15]:
model = Model(GCN(9)).to(device)
print(model)

Model(
  (encoder): GCN(
    (conv1): GCNConv(9, 8)
    (conv2): GCNConv(8, 8)
    (linear): Linear(in_features=8, out_features=8, bias=True)
    (bn): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linear1): Linear(in_features=138, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=256, out_features=64, bias=True)
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear3): Linear(in_features=64, out_features=1, bias=True)
)


In [16]:
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
criterion = nn.BCEWithLogitsLoss()
metric = torchmetrics.Accuracy('binary').to(device)

In [17]:
max_iter = 1000

for epoch in range(1, epochs + 1):
    model.train()
    optimizer.zero_grad()
    loss_list = []
    print(f'Epoch: {epoch:03d}')
    for i, data in enumerate(tqdm(train_dataloader, total=max_iter)):
        data = list(map(lambda x: x.to(device), data))
        outputs = model(*data[:-1])
        loss = criterion(outputs.flatten(), data[-1])
        loss.backward()
        optimizer.step()
        loss_list.append(loss)
        metric.update(torch.sigmoid(outputs.flatten()), data[-1])
        if i == max_iter:
            break
    print(f'TRAIN: Loss: {torch.mean(torch.tensor(loss_list))}, Accuracy: {metric.compute()}')
    metric.reset()

    model.eval()
    loss_list = []
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader, total=max_iter)):
            data = list(map(lambda x: x.to(device), data))
            outputs = model(*data[:-1])
            loss = criterion(outputs.flatten(), data[-1])
            loss_list.append(loss)
            metric.update(torch.sigmoid(outputs.flatten()), data[-1])
            if i == max_iter:
                break
    print(f'TEST: Loss: {torch.mean(torch.tensor(loss_list))}, Accuracy: {metric.compute()}\n')
    metric.reset()

Epoch: 001


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6600868701934814, Accuracy: 0.6049848794937134


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6711218953132629, Accuracy: 0.5952094793319702

Epoch: 002


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6516530513763428, Accuracy: 0.617331862449646


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.668521523475647, Accuracy: 0.5958728790283203

Epoch: 003


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6461334228515625, Accuracy: 0.6217571496963501


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6639199256896973, Accuracy: 0.6055936217308044

Epoch: 004


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6399983763694763, Accuracy: 0.6289413571357727


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6676169633865356, Accuracy: 0.5952523946762085

Epoch: 005


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6377419233322144, Accuracy: 0.6321998238563538


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6630582213401794, Accuracy: 0.6128870844841003

Epoch: 006


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6316570043563843, Accuracy: 0.6384513974189758


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6692550778388977, Accuracy: 0.607416033744812

Epoch: 007


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6301103830337524, Accuracy: 0.6402308344841003


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6603012681007385, Accuracy: 0.6070687174797058

Epoch: 008


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6254210472106934, Accuracy: 0.6455146670341492


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6617352962493896, Accuracy: 0.6086569428443909

Epoch: 009


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6240130066871643, Accuracy: 0.6461390256881714


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.657844603061676, Accuracy: 0.611587643623352

Epoch: 010


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6202420592308044, Accuracy: 0.6514188647270203


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6613965034484863, Accuracy: 0.6133397817611694

Epoch: 011


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6168770790100098, Accuracy: 0.6553758978843689


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6630293130874634, Accuracy: 0.6121378540992737

Epoch: 012


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6139726042747498, Accuracy: 0.6581582427024841


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6626190543174744, Accuracy: 0.6089613437652588

Epoch: 013


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6117029786109924, Accuracy: 0.6612176895141602


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6537297964096069, Accuracy: 0.620457649230957

Epoch: 014


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6087195873260498, Accuracy: 0.6638127565383911


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.655787467956543, Accuracy: 0.6153026819229126

Epoch: 015


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.605064332485199, Accuracy: 0.6679180264472961


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6572195887565613, Accuracy: 0.6149982810020447

Epoch: 016


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6032602787017822, Accuracy: 0.6696935892105103


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6590597033500671, Accuracy: 0.6187601685523987

Epoch: 017


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.6011210680007935, Accuracy: 0.6718008518218994


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6590617299079895, Accuracy: 0.6119544506072998

Epoch: 018


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.5978927612304688, Accuracy: 0.6757968664169312


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6571366786956787, Accuracy: 0.6212420463562012

Epoch: 019


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.5977434515953064, Accuracy: 0.6746222376823425


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.6593713760375977, Accuracy: 0.6163601875305176

Epoch: 020


  0%|          | 0/1000 [00:00<?, ?it/s]

TRAIN: Loss: 0.5934473872184753, Accuracy: 0.6785402297973633


  0%|          | 0/1000 [00:00<?, ?it/s]

TEST: Loss: 0.656201183795929, Accuracy: 0.6192479133605957

