# ATMSPY_DANN
套用 GitHub 上的 DANN 模板，在自己采集的数据上训练 DANN。

In [1]:
import random
import os
import sys
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
from data_loader import GetLoader
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import TensorDataset
from mymodel import MyCNNModel
from mytest import mytest
#from test import test

In [2]:
def _load_csv(path, dtype):
    """Read a comma-separated numeric file into a NumPy array of ``dtype``."""
    return np.loadtxt(path, delimiter=',', dtype=dtype)


# Source domain: recordings under collect_data/pzc/..., target domain:
# recordings under collect_data/qjw/... . Features are read as float32;
# labels as int64 because torch.nn.NLLLoss expects Long class indices.
source_x = _load_csv('collect_data/pzc/05_25_21_16/freqDomain/freqSignal.csv', np.float32)
source_y = _load_csv('collect_data/pzc/05_25_21_16/label.csv', np.int64)
target_x = _load_csv('collect_data/qjw/05_25_21_30/freqDomain/freqSignal.csv', np.float32)
target_y = _load_csv('collect_data/qjw/05_25_21_30/label.csv', np.int64)


In [8]:
# Hyper-parameters and run configuration.
# Use the GPU only when one is actually available: a hard-coded True makes
# every .cuda() call in the later cells fail on CPU-only machines.
cuda = torch.cuda.is_available()
# Let cuDNN benchmark and cache the fastest algorithms for the layer shapes seen.
cudnn.benchmark = True
lr = 1e-3
batch_size = 128
image_size = 28  # NOTE(review): not used by any cell shown here (MNIST template leftover)
n_epoch = 100
source_dataset_name = 'source'
target_dataset_name = 'target'
model_root = 'models'
# A fresh seed is drawn for every run. Print it so a run can be reproduced by
# pinning manual_seed to the printed value.
manual_seed = random.randint(1, 10000)
print('Random seed:', manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

<torch._C.Generator at 0x7eff1813c710>

In [4]:
# Replace the NumPy arrays with tensors. torch.from_numpy shares memory with
# the array (no copy) and keeps its dtype.
source_x, source_y, target_x, target_y = (
    torch.from_numpy(arr)
    for arr in (source_x, source_y, target_x, target_y)
)

# Pair each feature row with its label; TensorDataset indexes both tensors
# along their first dimension.
dataset_source = TensorDataset(source_x, source_y)
dataset_target = TensorDataset(target_x, target_y)


In [5]:
def _make_loader(dataset):
    """Build a shuffling DataLoader over ``dataset`` using the current global batch_size and 8 workers."""
    # NOTE(review): 8 worker processes are re-spawned every epoch by iter();
    # for a dataset this small that may cost more than it saves — measure.
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )


dataloader_source = _make_loader(dataset_source)
dataloader_target = _make_loader(dataset_target)

In [6]:
# load model

my_net = MyCNNModel()

# Move the model to the GPU *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer is built over the
# parameters that will actually be trained.
if cuda:
    my_net = my_net.cuda()

# Mark every parameter trainable. Parameters are trainable by default; this
# only matters if MyCNNModel freezes some of them itself.
for p in my_net.parameters():
    p.requires_grad = True

# setup optimizer

optimizer = optim.Adam(my_net.parameters(), lr=lr)

# NLLLoss expects log-probabilities, so both heads of MyCNNModel presumably
# end in log_softmax — TODO confirm in mymodel.py.
loss_class = torch.nn.NLLLoss()
loss_domain = torch.nn.NLLLoss()

if cuda:
    loss_class = loss_class.cuda()
    loss_domain = loss_domain.cuda()


In [9]:
# training
# The checkpoint directory must exist before the first torch.save below;
# otherwise saving fails on a fresh checkout.
os.makedirs(model_root, exist_ok=True)

best_accu_s = 0.0
best_accu_t = 0.0
# The two loaders are consumed in lockstep, so one epoch is as long as the
# shorter of them (zip() below stops at the same point).
len_dataloader = min(len(dataloader_source), len(dataloader_target))
for epoch in range(n_epoch):
    my_net.train()

    # zip() replaces the old `iterator.next()` calls, which newer PyTorch
    # DataLoader iterators no longer provide.
    paired_batches = zip(dataloader_source, dataloader_target)
    for i, ((s_img, s_label), (t_img, _)) in enumerate(paired_batches):

        # p is overall training progress in [0, 1); alpha ramps from 0 towards 1
        # with it and is passed to my_net (presumably the gradient-reversal
        # weight, as in the DANN template — confirm in mymodel.py).
        p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        my_net.zero_grad()

        # training model using source data (labelled; domain label 0).
        # Size the domain labels from the actual batch instead of overwriting
        # the global `batch_size` hyper-parameter (the last batch may be smaller).
        s_domain_label = torch.zeros(len(s_label)).long()

        if cuda:
            s_img = s_img.cuda()
            s_label = s_label.cuda()
            s_domain_label = s_domain_label.cuda()

        class_output, domain_output = my_net(input_data=s_img, alpha=alpha)
        err_s_label = loss_class(class_output, s_label)
        err_s_domain = loss_domain(domain_output, s_domain_label)

        # training model using target data (target labels are discarded;
        # domain label 1).
        t_domain_label = torch.ones(len(t_img)).long()

        if cuda:
            t_img = t_img.cuda()
            t_domain_label = t_domain_label.cuda()

        _, domain_output = my_net(input_data=t_img, alpha=alpha)
        err_t_domain = loss_domain(domain_output, t_domain_label)
        err = err_t_domain + err_s_domain + err_s_label
        err.backward()
        optimizer.step()

        sys.stdout.write('\r epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f'
                         % (epoch, i + 1, len_dataloader, err_s_label.item(),
                            err_s_domain.item(), err_t_domain.item()))
        sys.stdout.flush()

    # Save once per epoch instead of once per iteration. mytest presumably
    # evaluates this checkpoint, so the filename is kept unchanged.
    torch.save(my_net, '{0}/mnist_mnistm_model_epoch_current.pth'.format(model_root))

    print('\n')
    # NOTE(review): the recorded run stopped with an AssertionError right after
    # epoch 0's training output, presumably raised inside mytest(); check which
    # dataset names it accepts ('source' / 'target' are passed here).
    accu_s = mytest(source_dataset_name)
    print('Accuracy of the %s dataset: %f' % ('source: ', accu_s))
    accu_t = mytest(target_dataset_name)
    print('Accuracy of the %s dataset: %f\n' % ('target: ', accu_t))
    if accu_t > best_accu_t:
        best_accu_s = accu_s
        best_accu_t = accu_t
        torch.save(my_net, '{0}/mnist_mnistm_model_epoch_best.pth'.format(model_root))

 epoch: 0, [iter: 3 / all 3], err_s_label: 1.985848, err_s_domain: 0.734522, err_t_domain: 0.816328



AssertionError: 