# Загружаем картинки рыбок и котиков из сети

In [3]:
import os
import sys
import urllib3
from urllib.parse import urlparse
import pandas as pd
import itertools
import shutil

from urllib3.util import Retry

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

classes = ["cat", "fish"]
set_types = ["train", "test", "val"]

def download_image(url, klass, data_type):
    basename = os.path.basename(urlparse(url).path)
    filename = "{}/{}/{}".format(data_type, klass, basename)
    if not os.path.exists(filename):
        try: 
            http = urllib3.PoolManager(retries=Retry(connect=1, read=1, redirect=2))
            with http.request("GET", url, preload_content=False) as resp, open(
                filename, "wb"
            ) as out_file:
                if resp.status == 200:
                    shutil.copyfileobj(resp, out_file)
                else:
                    print("Error downloading {}".format(url))
            resp.release_conn()
        except:
            print("Error downloading {}".format(url))


if __name__ == "__main__":
    if not os.path.exists("images.csv"):
        print("Error: can't find images.csv!")
        sys.exit(0)

    # get args and create output directory
    imagesDF = pd.read_csv("images.csv")

    for set_type, klass in list(itertools.product(set_types, classes)):
        path = "./{}/{}".format(set_type, klass)
        if not os.path.exists(path):
            print("Creating directory {}".format(path))
            os.makedirs(path)

    print("Downloading {} images".format(len(imagesDF)))

    result = [
        download_image(url, klass, data_type)
        for url, klass, data_type in zip(
            imagesDF["url"], imagesDF["class"], imagesDF["type"]
        )
    ]
    sys.exit(0)


Creating directory ./train/cat
Creating directory ./train/fish
Creating directory ./test/cat
Creating directory ./test/fish
Creating directory ./val/cat
Creating directory ./val/fish
Downloading 1393 images
Error downloading http://farm1.static.flickr.com/1/1004525_cba96ba3c3.jpg
Error downloading http://farm1.static.flickr.com/33/65048097_e5264bf855.jpg
Error downloading http://farm3.static.flickr.com/2354/2102976081_61c8614be8.jpg
Error downloading http://farm2.static.flickr.com/1390/709949156_5e4ac3f499.jpg
Error downloading http://farm2.static.flickr.com/1221/1011749126_44a195db4c.jpg
Error downloading http://farm1.static.flickr.com/173/416994740_6ada308baa.jpg
Error downloading http://farm1.static.flickr.com/3/4193130_a058cdb81f.jpg
Error downloading http://farm1.static.flickr.com/1/1004528_a111209743.jpg
Error downloading http://farm1.static.flickr.com/223/522289072_13b4f92d39.jpg
Error downloading http://farm1.static.flickr.com/28/100739452_5be2c11557.jpg
Error downloading http:

Error downloading http://farm4.static.flickr.com/3183/2791133165_5df1d47be5.jpg
Error downloading http://farm4.static.flickr.com/3088/2776496087_1973f8dced.jpg
Error downloading http://farm4.static.flickr.com/3224/2777354896_176a518b8c.jpg
Error downloading http://farm2.static.flickr.com/1103/840200833_be72b99848.jpg
Error downloading http://farm2.static.flickr.com/1226/1140675688_2498ebdcc7.jpg
Error downloading http://farm4.static.flickr.com/3005/3007583030_bd590c07e7.jpg
Error downloading http://farm4.static.flickr.com/3218/2734977928_8d16b48c0a.jpg
Error downloading http://farm4.static.flickr.com/3218/2776494507_7e4cd3a67e.jpg
Error downloading http://farm2.static.flickr.com/1391/1087447360_c6037a47f5.jpg
Error downloading http://farm4.static.flickr.com/3124/2297032796_b5eb52b860.jpg
Error downloading http://farm4.static.flickr.com/3183/2806314103_c6a27c6a53.jpg
Error downloading http://farm3.static.flickr.com/2356/2076921635_8146b3766d.jpg
Error downloading http://farm4.static.fli

Error downloading http://michiganstreamside.com/upcoming2.jpg
Error downloading http://www.cordovarose.com/images/silversalmon.jpg
Error downloading http://www.driftingonthefly.com/images/rainbows/sctrnbw06_07.jpg
Error downloading http://www.tyeeatercharters.com/images/coho_salmon.jpg
Error downloading http://www.skeetchestn.ca/Natural%20Resources%20Website/SARAImages/CSalmon.jpg
Error downloading http://www2.kpr.edu.on.ca/cdciw/biomes/king1.jpg
Error downloading http://www.mattfender.net/alaska/dave%20with%20two%20silvers.jpg
Error downloading http://bonshellfishing.com/G.%20L.%20Salmon%20&%20Trout%20Description%20Pg%202007/Coho%20Salmon%20%20IMG_0970.JPG
Error downloading http://www.orcalodge.com/13.jpg
Error downloading http://www.alaskanfishingadventures.com/PB280171.JPG
Error downloading http://img5.travelblog.org/Photos/45963/199454/t/1491860-Big-silver-0.jpg
Error downloading http://www.zenwaiterwest.com/photos/july22-30%202005%20Salmon%20Fishing%20%20Ch%202/coho%20spring%20lin

SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


# Инициализируем сеть

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True

In [4]:
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

In [5]:
img_transforms = transforms.Compose([
    transforms.Resize((64,64)),    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225] )
    ])

In [6]:
train_data_path = "./train/"
train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms, is_valid_file=check_image)

val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=img_transforms, is_valid_file=check_image)

test_data_path = "./test/"
test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=img_transforms, is_valid_file=check_image)

batch_size=64

train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader  = torch.utils.data.DataLoader(val_data, batch_size=batch_size) 
test_data_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

In [7]:
class SimpleNet(nn.Module):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(12288, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50,2)
    
    def forward(self, x):
        x = x.view(-1, 12288)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [8]:
simplenet = SimpleNet()

In [9]:
simplenet.load_state_dict(torch.load("./tmp/simplenet/simplenet.pth"))

<All keys matched successfully>

In [10]:
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)

In [11]:
check_image(val_data_path)

False

In [12]:
if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

simplenet.to(device)

SimpleNet(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

# Тренеровка сети

In [21]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(1, epochs+1):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets) 
            valid_loss += loss.data.item() * inputs.size(0)
            correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, num_correct / num_examples))

In [22]:
train(simplenet, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=5, device=device)

Epoch: 1, Training Loss: 1.83, Validation Loss: 5.49, accuracy = 0.35
Epoch: 2, Training Loss: 2.74, Validation Loss: 0.68, accuracy = 0.73
Epoch: 3, Training Loss: 0.68, Validation Loss: 0.69, accuracy = 0.63
Epoch: 4, Training Loss: 0.60, Validation Loss: 0.61, accuracy = 0.75
Epoch: 5, Training Loss: 0.43, Validation Loss: 0.63, accuracy = 0.68


In [15]:
labels = ['cat','fish']

img = Image.open("./val/fish/SteveSilver2_1.jpg") 
img = img_transforms(img).to(device)

simplenet.eval()
prediction = F.softmax(simplenet(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction])

fish


In [43]:
torch.save(simplenet.state_dict(), os.path.join("./tmp/simplenet", "simplenet.pth"))

In [82]:
#simplenet = torch.load("./tmp/simplenet/simplenet.pth")
simplenet.load_state_dict(torch.load("./tmp/simplenet/simplenet.pth"))
#simplenet.eval()

RuntimeError: CUDA error: an illegal memory access was encountered