In [1]:
# target parameter attack on linear regression with close form solution on the cross derivative

import os
import time
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from numpy import linalg as LA
import numpy as np
import math
from tqdm import tqdm
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib
from sklearn.datasets import make_classification


torch.manual_seed(0)
device = 'cuda'

In [2]:
# creating the gaussian dataset

# define training set
separable = False
while not separable:
    samples = make_classification(n_samples=1000, n_features=3, n_redundant=0, n_informative=1, n_clusters_per_class=1, flip_y=-1)
    red = samples[0][samples[1] == 0]
    blue = samples[0][samples[1] == 1]
    separable = any([red[:, k].max() < blue[:, k].min() or red[:, k].min() > blue[:, k].max() for k in range(2)])
red_labels = np.zeros(len(red))
blue_labels = np.ones(len(blue))

labels = np.append(red_labels,blue_labels)
inputs = np.concatenate((red,blue),axis=0)

X_train, X_test, y_train,  y_test = train_test_split(
    inputs, labels, test_size=0.33, random_state=42)

X_train, X_test = torch.Tensor(X_train),torch.Tensor(X_test)
y_train, y_test = torch.Tensor(y_train),torch.Tensor(y_test)



In [3]:
print(y_train.size())
y = y_train.repeat(3)
print(y.size())

torch.Size([670])
torch.Size([2010])


In [4]:
batch_size_train= len(X_train)
batch_size_test = len(X_test)
class LinearDataset(Dataset):
    def __init__(self, X, y):
        assert X.size()[0] == y.size()[0]
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.size()[0]

    def __getitem__(self, idx):
        
        
        return [self.X[idx], self.y[idx]]
train_loader = DataLoader(LinearDataset(X_train, y_train), batch_size=batch_size_train, shuffle=False)
test_loader = DataLoader(LinearDataset(X_test, y_test), batch_size=batch_size_test, shuffle=False)

In [5]:
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        #outputs = torch.sigmoid(self.linear(x))
        outputs = self.linear(x)
        return outputs

input_dim = 3 
output_dim = 1 # Two possible outputs
learning_rate = 0.01

model = LogisticRegression(input_dim,output_dim).to(device)
model.load_state_dict(torch.load("gaussian_gd_0.5.pt"))

<All keys matched successfully>

In [6]:
from scipy.special import lambertw
import math
lambert_w = lambertw(1/math.e)
print(lambert_w)

(0.2784645427610738+0j)


In [7]:
for param in model.parameters():
    w_p = param
    break
print(w_p)

Parameter containing:
tensor([[-0.3458, -2.4712,  0.1935]], device='cuda:0', requires_grad=True)


In [8]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

linear.weight tensor([[-0.3458, -2.4712,  0.1935]], device='cuda:0')
linear.bias tensor([1.2288], device='cuda:0')


In [9]:
# try scaling the weights 
state_dict = model.state_dict()
state_dict['linear.weight'] = torch.tensor([[-0.03458,-0.24712,0.01935]])
model.load_state_dict(state_dict)

<All keys matched successfully>

In [12]:
# the script for training target attack
epsilon = 0.06
lr = 1
epochs = 2000

def adjust_learning_rate(lr, epoch):
    """Decay the learning rate based on schedule"""
    lr *= 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    return(lr)


def autograd(outputs, inputs, create_graph=False):
    """Compute gradient of outputs w.r.t. inputs, assuming outputs is a scalar."""
    #inputs = tuple(inputs)
    grads = torch.autograd.grad(outputs, inputs, create_graph=create_graph, allow_unused=True)
    return [xx if xx is not None else yy.new_zeros(yy.size()) for xx, yy in zip(grads, inputs)]

def train(epoch, X,y):
    data, target = X.to(device), y.to(device)
    data.requires_grad=True
    if epoch==0:
        # initialize poisoned data
        if epsilon<=1:
            data_p = Variable(data[:(int(epsilon*len(data)))])
            target_p = Variable(target[:(int(epsilon*len(target)))])
        else:
            data_p = Variable(torch.cat((data.repeat(int(epsilon),1),data[:(int((epsilon-int(epsilon))*len(data)))]),0))
            target_p = Variable(torch.cat((target.repeat(int(epsilon)),target[:(int((epsilon-int(epsilon))*len(target)))]),0))
        torch.save(target_p,'target_p_gaussian_{}.pt'.format(epsilon))
    else:
        data_p = torch.load('data_p_gaussian_{}.pt'.format(epsilon))
        target_p = torch.load('target_p_gaussian_{}.pt'.format(epsilon))
    data_p.requires_grad=True

    # initialize f function
    criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')

    # calculate gradient of w on clean sample
    output_c = torch.squeeze(model(data))
    loss_c =  0.5 * criterion(output_c,target)

    # calculate dL/dg_1
    grad_c= autograd(loss_c,tuple(model.parameters()),create_graph=True)
    g1 = grad_c[0]
    g_mu = g1/len(y)
    
    # calculate the size of epsilon_d w.r.t g1 and the Lambert's W function
    g_mu_dot_w = np.dot(g_mu.to('cpu').detach().numpy().squeeze(),w_p.cpu().detach().numpy().squeeze())
    if epoch == 0:
        print(g_mu_dot_w)
        print('the necessary size of epsilon_d:{}'.format(g_mu_dot_w/lambert_w))
    

    # calculate gradient of w on poisoned sample
    output_p = torch.squeeze(model(data_p))
    loss_p = 0.5 * criterion(output_p,target_p)
    grad_p= autograd(loss_p,tuple(model.parameters()),create_graph=True)
    g2= grad_p[0]
    
    # calculate the true loss: |g_c + g_p|_{2}
    grad_sum = g1+g2
    loss = torch.norm(grad_sum,2)

    update = autograd(loss,data_p,create_graph=True)
    data_t = data_p - lr * update[0]


    data_t = data_t.to('cuda')

    torch.save(data_t, 'data_p_gaussian_{}.pt'.format(epsilon))
    

    print("epoch:{},lr:{},loss:{}".format(epoch,lr,loss))
    
    return loss
        

In [13]:
for epoch in range(epochs):
    loss = train(epoch,X_train,y_train)
    if loss<0.0001:
        break

0.01681924
the necessary size of epsilon_d:(0.060399935871893046+0j)
epoch:0,lr:1,loss:182.27609252929688
epoch:1,lr:1,loss:178.17230224609375
epoch:2,lr:1,loss:174.17391967773438
epoch:3,lr:1,loss:170.27223205566406
epoch:4,lr:1,loss:166.45852661132812
epoch:5,lr:1,loss:162.72418212890625
epoch:6,lr:1,loss:159.06072998046875
epoch:7,lr:1,loss:155.45999145507812
epoch:8,lr:1,loss:151.91409301757812
epoch:9,lr:1,loss:148.4156494140625
epoch:10,lr:1,loss:144.9577178955078
epoch:11,lr:1,loss:141.53396606445312
epoch:12,lr:1,loss:138.13864135742188
epoch:13,lr:1,loss:134.76675415039062
epoch:14,lr:1,loss:131.4139404296875
epoch:15,lr:1,loss:128.07675170898438
epoch:16,lr:1,loss:124.75247192382812
epoch:17,lr:1,loss:121.43930053710938
epoch:18,lr:1,loss:118.13629150390625
epoch:19,lr:1,loss:114.84346771240234
epoch:20,lr:1,loss:111.56173706054688
epoch:21,lr:1,loss:108.29291534423828
epoch:22,lr:1,loss:105.03974914550781
epoch:23,lr:1,loss:101.80583190917969
epoch:24,lr:1,loss:98.5955810546

epoch:243,lr:1,loss:9.652846336364746
epoch:244,lr:1,loss:9.641124725341797
epoch:245,lr:1,loss:9.629534721374512
epoch:246,lr:1,loss:9.618086814880371
epoch:247,lr:1,loss:9.606769561767578
epoch:248,lr:1,loss:9.595587730407715
epoch:249,lr:1,loss:9.58453369140625
epoch:250,lr:1,loss:9.5736083984375
epoch:251,lr:1,loss:9.562809944152832
epoch:252,lr:1,loss:9.552139282226562
epoch:253,lr:1,loss:9.541584014892578
epoch:254,lr:1,loss:9.53115177154541
epoch:255,lr:1,loss:9.520841598510742
epoch:256,lr:1,loss:9.510648727416992
epoch:257,lr:1,loss:9.500565528869629
epoch:258,lr:1,loss:9.49060344696045
epoch:259,lr:1,loss:9.480746269226074
epoch:260,lr:1,loss:9.471002578735352
epoch:261,lr:1,loss:9.461369514465332
epoch:262,lr:1,loss:9.451842308044434
epoch:263,lr:1,loss:9.442420959472656
epoch:264,lr:1,loss:9.433101654052734
epoch:265,lr:1,loss:9.4238862991333
epoch:266,lr:1,loss:9.414773941040039
epoch:267,lr:1,loss:9.405755996704102
epoch:268,lr:1,loss:9.396841049194336
epoch:269,lr:1,loss

epoch:457,lr:1,loss:10.535860061645508
epoch:458,lr:1,loss:10.657929420471191
epoch:459,lr:1,loss:10.539916038513184
epoch:460,lr:1,loss:10.662890434265137
epoch:461,lr:1,loss:10.54394817352295
epoch:462,lr:1,loss:10.667826652526855
epoch:463,lr:1,loss:10.547944068908691
epoch:464,lr:1,loss:10.672720909118652
epoch:465,lr:1,loss:10.551900863647461
epoch:466,lr:1,loss:10.677557945251465
epoch:467,lr:1,loss:10.555815696716309
epoch:468,lr:1,loss:10.682360649108887
epoch:469,lr:1,loss:10.559691429138184
epoch:470,lr:1,loss:10.687106132507324
epoch:471,lr:1,loss:10.563517570495605
epoch:472,lr:1,loss:10.69179630279541
epoch:473,lr:1,loss:10.56730842590332
epoch:474,lr:1,loss:10.696468353271484
epoch:475,lr:1,loss:10.571078300476074
epoch:476,lr:1,loss:10.701087951660156
epoch:477,lr:1,loss:10.574795722961426
epoch:478,lr:1,loss:10.705658912658691
epoch:479,lr:1,loss:10.57847785949707
epoch:480,lr:1,loss:10.710201263427734
epoch:481,lr:1,loss:10.582131385803223
epoch:482,lr:1,loss:10.714683

epoch:682,lr:1,loss:11.016570091247559
epoch:683,lr:1,loss:10.822733879089355
epoch:684,lr:1,loss:11.018526077270508
epoch:685,lr:1,loss:10.824240684509277
epoch:686,lr:1,loss:11.020484924316406
epoch:687,lr:1,loss:10.825739860534668
epoch:688,lr:1,loss:11.022428512573242
epoch:689,lr:1,loss:10.827219009399414
epoch:690,lr:1,loss:11.024356842041016
epoch:691,lr:1,loss:10.828691482543945
epoch:692,lr:1,loss:11.02625846862793
epoch:693,lr:1,loss:10.830144882202148
epoch:694,lr:1,loss:11.028143882751465
epoch:695,lr:1,loss:10.831594467163086
epoch:696,lr:1,loss:11.030020713806152
epoch:697,lr:1,loss:10.83301830291748
epoch:698,lr:1,loss:11.031885147094727
epoch:699,lr:1,loss:10.83444881439209
epoch:700,lr:1,loss:11.033729553222656
epoch:701,lr:1,loss:10.835858345031738
epoch:702,lr:1,loss:11.035569190979004
epoch:703,lr:1,loss:10.837251663208008
epoch:704,lr:1,loss:11.037382125854492
epoch:705,lr:1,loss:10.838643074035645
epoch:706,lr:1,loss:11.03918743133545
epoch:707,lr:1,loss:10.840022

epoch:925,lr:1,loss:10.94080924987793
epoch:926,lr:1,loss:11.174129486083984
epoch:927,lr:1,loss:10.94139289855957
epoch:928,lr:1,loss:11.174918174743652
epoch:929,lr:1,loss:10.941980361938477
epoch:930,lr:1,loss:11.17570972442627
epoch:931,lr:1,loss:10.942551612854004
epoch:932,lr:1,loss:11.176481246948242
epoch:933,lr:1,loss:10.943124771118164
epoch:934,lr:1,loss:11.177263259887695
epoch:935,lr:1,loss:10.943700790405273
epoch:936,lr:1,loss:11.1780366897583
epoch:937,lr:1,loss:10.944262504577637
epoch:938,lr:1,loss:11.178790092468262
epoch:939,lr:1,loss:10.94482421875
epoch:940,lr:1,loss:11.179542541503906
epoch:941,lr:1,loss:10.945371627807617
epoch:942,lr:1,loss:11.18028736114502
epoch:943,lr:1,loss:10.945921897888184
epoch:944,lr:1,loss:11.18103313446045
epoch:945,lr:1,loss:10.946481704711914
epoch:946,lr:1,loss:11.181779861450195
epoch:947,lr:1,loss:10.94702434539795
epoch:948,lr:1,loss:11.182525634765625
epoch:949,lr:1,loss:10.947563171386719
epoch:950,lr:1,loss:11.18323040008545

epoch:1157,lr:1,loss:10.985217094421387
epoch:1158,lr:1,loss:11.234983444213867
epoch:1159,lr:1,loss:10.985435485839844
epoch:1160,lr:1,loss:11.2352933883667
epoch:1161,lr:1,loss:10.9856538772583
epoch:1162,lr:1,loss:11.235604286193848
epoch:1163,lr:1,loss:10.985864639282227
epoch:1164,lr:1,loss:11.235916137695312
epoch:1165,lr:1,loss:10.986095428466797
epoch:1166,lr:1,loss:11.236226081848145
epoch:1167,lr:1,loss:10.986306190490723
epoch:1168,lr:1,loss:11.236549377441406
epoch:1169,lr:1,loss:10.986531257629395
epoch:1170,lr:1,loss:11.23685073852539
epoch:1171,lr:1,loss:10.986737251281738
epoch:1172,lr:1,loss:11.237130165100098
epoch:1173,lr:1,loss:10.986933708190918
epoch:1174,lr:1,loss:11.237432479858398
epoch:1175,lr:1,loss:10.987151145935059
epoch:1176,lr:1,loss:11.237723350524902
epoch:1177,lr:1,loss:10.987357139587402
epoch:1178,lr:1,loss:11.238037109375
epoch:1179,lr:1,loss:10.987560272216797
epoch:1180,lr:1,loss:11.23830795288086
epoch:1181,lr:1,loss:10.987767219543457
epoch:118

epoch:1380,lr:1,loss:11.256707191467285
epoch:1381,lr:1,loss:10.999988555908203
epoch:1382,lr:1,loss:11.256818771362305
epoch:1383,lr:1,loss:11.00005054473877
epoch:1384,lr:1,loss:11.256917953491211
epoch:1385,lr:1,loss:11.000105857849121
epoch:1386,lr:1,loss:11.257007598876953
epoch:1387,lr:1,loss:11.000160217285156
epoch:1388,lr:1,loss:11.257109642028809
epoch:1389,lr:1,loss:11.000224113464355
epoch:1390,lr:1,loss:11.257211685180664
epoch:1391,lr:1,loss:11.000271797180176
epoch:1392,lr:1,loss:11.25728988647461
epoch:1393,lr:1,loss:11.000327110290527
epoch:1394,lr:1,loss:11.257391929626465
epoch:1395,lr:1,loss:11.000384330749512
epoch:1396,lr:1,loss:11.25749397277832
epoch:1397,lr:1,loss:11.000442504882812
epoch:1398,lr:1,loss:11.257562637329102
epoch:1399,lr:1,loss:11.000479698181152
epoch:1400,lr:1,loss:11.25765323638916
epoch:1401,lr:1,loss:11.000536918640137
epoch:1402,lr:1,loss:11.257744789123535
epoch:1403,lr:1,loss:11.000577926635742
epoch:1404,lr:1,loss:11.257814407348633
epoc

epoch:1619,lr:1,loss:11.001333236694336
epoch:1620,lr:1,loss:11.26095199584961
epoch:1621,lr:1,loss:11.00129508972168
epoch:1622,lr:1,loss:11.260924339294434
epoch:1623,lr:1,loss:11.001267433166504
epoch:1624,lr:1,loss:11.260908126831055
epoch:1625,lr:1,loss:11.001239776611328
epoch:1626,lr:1,loss:11.260891914367676
epoch:1627,lr:1,loss:11.001218795776367
epoch:1628,lr:1,loss:11.260885238647461
epoch:1629,lr:1,loss:11.00118350982666
epoch:1630,lr:1,loss:11.260848045349121
epoch:1631,lr:1,loss:11.001153945922852
epoch:1632,lr:1,loss:11.260843276977539
epoch:1633,lr:1,loss:11.001127243041992
epoch:1634,lr:1,loss:11.260814666748047
epoch:1635,lr:1,loss:11.0010986328125
epoch:1636,lr:1,loss:11.260809898376465
epoch:1637,lr:1,loss:11.00107192993164
epoch:1638,lr:1,loss:11.260782241821289
epoch:1639,lr:1,loss:11.001029014587402
epoch:1640,lr:1,loss:11.260757446289062
epoch:1641,lr:1,loss:11.001018524169922
epoch:1642,lr:1,loss:11.260749816894531
epoch:1643,lr:1,loss:11.000975608825684
epoch:

epoch:1849,lr:1,loss:10.995741844177246
epoch:1850,lr:1,loss:11.255754470825195
epoch:1851,lr:1,loss:10.995667457580566
epoch:1852,lr:1,loss:11.255680084228516
epoch:1853,lr:1,loss:10.995609283447266
epoch:1854,lr:1,loss:11.255614280700684
epoch:1855,lr:1,loss:10.99553394317627
epoch:1856,lr:1,loss:11.255537986755371
epoch:1857,lr:1,loss:10.995468139648438
epoch:1858,lr:1,loss:11.255463600158691
epoch:1859,lr:1,loss:10.995399475097656
epoch:1860,lr:1,loss:11.25540828704834
epoch:1861,lr:1,loss:10.995343208312988
epoch:1862,lr:1,loss:11.255343437194824
epoch:1863,lr:1,loss:10.995268821716309
epoch:1864,lr:1,loss:11.255256652832031
epoch:1865,lr:1,loss:10.995193481445312
epoch:1866,lr:1,loss:11.255181312561035
epoch:1867,lr:1,loss:10.995135307312012
epoch:1868,lr:1,loss:11.255126953125
epoch:1869,lr:1,loss:10.995061874389648
epoch:1870,lr:1,loss:11.25503921508789
epoch:1871,lr:1,loss:10.994987487792969
epoch:1872,lr:1,loss:11.254976272583008
epoch:1873,lr:1,loss:10.994930267333984
epoch:

In [14]:
epsilon= 0.06
data_p = torch.load('data_p_gaussian_{}.pt'.format(epsilon))
target_p = torch.load('target_p_gaussian_{}.pt'.format(epsilon))
#print(data_p)

In [15]:
print(data_p.size())

torch.Size([40, 3])


In [16]:
X_all = torch.cat((X_train.to(device),data_p),0).to('cpu')
y_all = torch.cat((y_train.to(device),target_p),0).to('cpu')
print(X_all[:,0].size())

torch.Size([710])


In [17]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

linear.weight tensor([[-0.0346, -0.2471,  0.0193]], device='cuda:0')
linear.bias tensor([1.2288], device='cuda:0')


In [18]:
train_loader_all = DataLoader(LinearDataset(X_all, y_all), batch_size=batch_size_train, shuffle=False)
device = 'cuda'
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [19]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        criterion = torch.nn.BCEWithLogitsLoss()
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        #output = model(data)
        loss = criterion(torch.squeeze(output), target)
        loss.backward()
        optimizer.step()
        if batch_idx % 200 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


In [20]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            criterion = torch.nn.BCEWithLogitsLoss()
            data, target = data.to(device), target.to(device)
            #output = model(data.view(data.size(0), -1))
            output = torch.squeeze(model(data))
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = torch.squeeze(output).round()  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [21]:
model1 = LogisticRegression(input_dim,output_dim).to(device)
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.1)

for epoch in range(100):
    train(model1, device, train_loader_all,optimizer1,epoch)
    test(model1,device, test_loader)

2it [00:00, 62.81it/s]



Test set: Average loss: 0.0026, Accuracy: 10/330 (3.03%)



2it [00:00, 75.62it/s]



Test set: Average loss: 0.0026, Accuracy: 10/330 (3.03%)



2it [00:00, 93.47it/s]



Test set: Average loss: 0.0026, Accuracy: 7/330 (2.12%)



2it [00:00, 85.17it/s]



Test set: Average loss: 0.0026, Accuracy: 6/330 (1.82%)



2it [00:00, 85.73it/s]



Test set: Average loss: 0.0026, Accuracy: 5/330 (1.52%)



2it [00:00, 91.46it/s]



Test set: Average loss: 0.0026, Accuracy: 3/330 (0.91%)



2it [00:00, 88.74it/s]



Test set: Average loss: 0.0026, Accuracy: 1/330 (0.30%)



2it [00:00, 66.26it/s]



Test set: Average loss: 0.0026, Accuracy: 1/330 (0.30%)



2it [00:00, 82.68it/s]



Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



2it [00:00, 96.36it/s]



Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



2it [00:00, 82.71it/s]



Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



2it [00:00, 70.87it/s]







Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



2it [00:00, 96.83it/s]



Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



2it [00:00, 95.39it/s]



Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



2it [00:00, 90.17it/s]



Test set: Average loss: 0.0025, Accuracy: 0/330 (0.00%)



0it [00:00, ?it/s]



2it [00:00, 89.30it/s]



Test set: Average loss: 0.0025, Accuracy: 7/330 (2.12%)



2it [00:00, 89.19it/s]



Test set: Average loss: 0.0025, Accuracy: 12/330 (3.64%)



2it [00:00, 95.68it/s]



Test set: Average loss: 0.0025, Accuracy: 18/330 (5.45%)



2it [00:00, 92.36it/s]







Test set: Average loss: 0.0025, Accuracy: 25/330 (7.58%)



2it [00:00, 98.00it/s]



Test set: Average loss: 0.0025, Accuracy: 34/330 (10.30%)



2it [00:00, 93.78it/s]



Test set: Average loss: 0.0025, Accuracy: 46/330 (13.94%)



2it [00:00, 84.35it/s]



Test set: Average loss: 0.0025, Accuracy: 60/330 (18.18%)



2it [00:00, 92.39it/s]







Test set: Average loss: 0.0025, Accuracy: 65/330 (19.70%)



2it [00:00, 89.23it/s]



Test set: Average loss: 0.0025, Accuracy: 70/330 (21.21%)



2it [00:00, 87.28it/s]



Test set: Average loss: 0.0025, Accuracy: 75/330 (22.73%)



0it [00:00, ?it/s]



2it [00:00, 83.96it/s]



Test set: Average loss: 0.0025, Accuracy: 76/330 (23.03%)



2it [00:00, 95.36it/s]



Test set: Average loss: 0.0025, Accuracy: 82/330 (24.85%)



2it [00:00, 90.44it/s]



Test set: Average loss: 0.0025, Accuracy: 86/330 (26.06%)



0it [00:00, ?it/s]



2it [00:00, 96.68it/s]



Test set: Average loss: 0.0025, Accuracy: 90/330 (27.27%)



0it [00:00, ?it/s]



2it [00:00, 96.95it/s]



Test set: Average loss: 0.0025, Accuracy: 92/330 (27.88%)



0it [00:00, ?it/s]



2it [00:00, 90.62it/s]



Test set: Average loss: 0.0025, Accuracy: 93/330 (28.18%)



0it [00:00, ?it/s]



2it [00:00, 81.38it/s]



Test set: Average loss: 0.0025, Accuracy: 93/330 (28.18%)



2it [00:00, 95.40it/s]







Test set: Average loss: 0.0025, Accuracy: 96/330 (29.09%)



2it [00:00, 79.06it/s]



Test set: Average loss: 0.0025, Accuracy: 97/330 (29.39%)



2it [00:00, 100.36it/s]







Test set: Average loss: 0.0025, Accuracy: 98/330 (29.70%)



2it [00:00, 62.01it/s]



Test set: Average loss: 0.0025, Accuracy: 100/330 (30.30%)



2it [00:00, 95.93it/s]



Test set: Average loss: 0.0025, Accuracy: 101/330 (30.61%)



2it [00:00, 80.28it/s]



Test set: Average loss: 0.0025, Accuracy: 101/330 (30.61%)



2it [00:00, 80.86it/s]



Test set: Average loss: 0.0025, Accuracy: 102/330 (30.91%)



2it [00:00, 93.36it/s]



Test set: Average loss: 0.0024, Accuracy: 104/330 (31.52%)



2it [00:00, 96.11it/s]



Test set: Average loss: 0.0024, Accuracy: 104/330 (31.52%)



2it [00:00, 75.23it/s]



Test set: Average loss: 0.0024, Accuracy: 107/330 (32.42%)



2it [00:00, 93.20it/s]



Test set: Average loss: 0.0024, Accuracy: 108/330 (32.73%)



2it [00:00, 93.27it/s]



Test set: Average loss: 0.0024, Accuracy: 109/330 (33.03%)



0it [00:00, ?it/s]



2it [00:00, 94.92it/s]



Test set: Average loss: 0.0024, Accuracy: 111/330 (33.64%)



2it [00:00, 84.23it/s]







Test set: Average loss: 0.0024, Accuracy: 112/330 (33.94%)



0it [00:00, ?it/s]



2it [00:00, 79.27it/s]



Test set: Average loss: 0.0024, Accuracy: 114/330 (34.55%)



2it [00:00, 70.49it/s]



Test set: Average loss: 0.0024, Accuracy: 114/330 (34.55%)



2it [00:00, 96.25it/s]



Test set: Average loss: 0.0024, Accuracy: 114/330 (34.55%)



2it [00:00, 95.35it/s]



Test set: Average loss: 0.0024, Accuracy: 114/330 (34.55%)



2it [00:00, 96.62it/s]



Test set: Average loss: 0.0024, Accuracy: 115/330 (34.85%)



2it [00:00, 96.08it/s]







Test set: Average loss: 0.0024, Accuracy: 115/330 (34.85%)



2it [00:00, 96.93it/s]



Test set: Average loss: 0.0024, Accuracy: 116/330 (35.15%)



2it [00:00, 73.99it/s]



Test set: Average loss: 0.0024, Accuracy: 116/330 (35.15%)



2it [00:00, 86.54it/s]



Test set: Average loss: 0.0024, Accuracy: 116/330 (35.15%)



0it [00:00, ?it/s]



2it [00:00, 81.77it/s]



Test set: Average loss: 0.0024, Accuracy: 119/330 (36.06%)



2it [00:00, 89.00it/s]



Test set: Average loss: 0.0024, Accuracy: 119/330 (36.06%)



2it [00:00, 78.34it/s]



Test set: Average loss: 0.0024, Accuracy: 119/330 (36.06%)



2it [00:00, 95.79it/s]



Test set: Average loss: 0.0024, Accuracy: 119/330 (36.06%)



2it [00:00, 92.40it/s]



Test set: Average loss: 0.0024, Accuracy: 120/330 (36.36%)



2it [00:00, 82.39it/s]



Test set: Average loss: 0.0024, Accuracy: 120/330 (36.36%)



2it [00:00, 88.78it/s]



Test set: Average loss: 0.0024, Accuracy: 120/330 (36.36%)



0it [00:00, ?it/s]



2it [00:00, 83.06it/s]



Test set: Average loss: 0.0024, Accuracy: 120/330 (36.36%)



2it [00:00, 85.83it/s]



Test set: Average loss: 0.0024, Accuracy: 120/330 (36.36%)



2it [00:00, 88.11it/s]



Test set: Average loss: 0.0024, Accuracy: 121/330 (36.67%)



2it [00:00, 95.82it/s]



Test set: Average loss: 0.0024, Accuracy: 121/330 (36.67%)



2it [00:00, 84.71it/s]



Test set: Average loss: 0.0024, Accuracy: 121/330 (36.67%)



2it [00:00, 81.37it/s]



Test set: Average loss: 0.0024, Accuracy: 121/330 (36.67%)



2it [00:00, 68.69it/s]


Test set: Average loss: 0.0024, Accuracy: 122/330 (36.97%)




2it [00:00, 92.45it/s]



Test set: Average loss: 0.0024, Accuracy: 123/330 (37.27%)



2it [00:00, 89.81it/s]



Test set: Average loss: 0.0024, Accuracy: 124/330 (37.58%)



2it [00:00, 96.36it/s]



Test set: Average loss: 0.0024, Accuracy: 125/330 (37.88%)



2it [00:00, 95.85it/s]



Test set: Average loss: 0.0024, Accuracy: 125/330 (37.88%)



2it [00:00, 96.99it/s]



Test set: Average loss: 0.0024, Accuracy: 127/330 (38.48%)



2it [00:00, 100.56it/s]







Test set: Average loss: 0.0024, Accuracy: 127/330 (38.48%)



2it [00:00, 87.70it/s]



Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 90.72it/s]



Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 89.27it/s]







Test set: Average loss: 0.0024, Accuracy: 127/330 (38.48%)



2it [00:00, 100.21it/s]







Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 101.92it/s]







Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 88.47it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 100.03it/s]







Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 90.12it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



0it [00:00, ?it/s]



2it [00:00, 93.88it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 88.70it/s]







Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 75.73it/s]







Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



0it [00:00, ?it/s]



2it [00:00, 77.73it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 91.39it/s]



Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 99.03it/s]







Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 98.01it/s]



Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 94.86it/s]







Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



0it [00:00, ?it/s]



2it [00:00, 21.26it/s]



Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 72.36it/s]



Test set: Average loss: 0.0024, Accuracy: 128/330 (38.79%)



2it [00:00, 105.22it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 91.02it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 93.23it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 104.51it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 78.21it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 61.21it/s]



Test set: Average loss: 0.0024, Accuracy: 129/330 (39.09%)



2it [00:00, 88.57it/s]


Test set: Average loss: 0.0024, Accuracy: 130/330 (39.39%)






In [22]:
for name, param in model1.named_parameters():
    if param.requires_grad:
        print(name, param.data)

linear.weight tensor([[-0.1849, -0.3146,  0.0776]], device='cuda:0')
linear.bias tensor([0.1418], device='cuda:0')


In [23]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

linear.weight tensor([[-0.0346, -0.2471,  0.0193]], device='cuda:0')
linear.bias tensor([1.2288], device='cuda:0')
