# Simple Neural Network using PyTorch #

Author: **Pier Luca Anania**

In [3]:
pip install torch

Note: you may need to restart the kernel to use updated packages.


In this example we're going to use the **Iris dataset**.

In [6]:
import torch
import torch.nn as nn                   #Neural Network
import torch.nn.functional as F         #to move data forward in our function

Create a model class that inherits from `nn.Module`.

In [13]:
class Model(nn.Module):
    """Small fully connected classifier for the Iris dataset.

    Architecture:
      - input layer: 4 features of the flower (sepal length/width, petal length/width)
      - hidden layer 1: ``h1`` neurons, ReLU activation
      - hidden layer 2: ``h2`` neurons, ReLU activation
      - output layer: one raw, un-normalized score per class -- 3 classes of
        Iris flowers (Iris Setosa / Versicolour / Virginica)

    Args:
        in_features: number of input features (default 4).
        h1: number of neurons in the first hidden layer (default 8).
        h2: number of neurons in the second hidden layer (default 9).
        out_features: number of output classes (default 3).
    """

    def __init__(self, in_features=4, h1=8, h2=9, out_features=3):
        super().__init__()  # set up nn.Module internals before registering layers
        # "fc" = fully connected. Layers are created in the order fc1 -> fc2 -> out
        # so that, under a fixed seed, their random initial weights are reproducible.
        self.fc1 = nn.Linear(in_features, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.out = nn.Linear(h2, out_features)

    def forward(self, x):
        """Map input features ``x`` to raw class scores (no softmax applied)."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.out(hidden)


Set a manual random seed so that the model's initial weights are reproducible.

In [14]:
# Fixed seed so the random weight initialization of nn.Linear layers is reproducible.
# The trailing ';' suppresses the `<torch._C.Generator ...>` repr in the cell output.
SEED = 41
torch.manual_seed(SEED);

<torch._C.Generator at 0x162894c8f90>

Create an instance of our `Model`.

In [15]:
# Spell out the architecture: 4 inputs -> 8 -> 9 hidden units -> 3 classes.
model = Model(in_features=4, h1=8, h2=9, out_features=3)

## Working with Data

In [24]:
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline 

url = 'https://gist.github.com/netj/8836201.js'
my_df = pd.read_csv(url)
my_df

Unnamed: 0,"document.write('<link rel=""stylesheet"" href=""https://github.githubassets.com/assets/gist-embed-f65f23c5975e.css"">')"
"document.write('<div id=\""gist8836201\"" class=\""gist\"">\n <div class=\""gist-file\"" translate=\""no\"">\n <div class=\""gist-data\"">\n <div class=\""js-gist-file-update-container js-task-list-container file-box\"">\n <div id=\""file-iris-csv\"" class=\""file my-2\"">\n \n <div itemprop=\""text\"" class=\""Box-body p-0 blob-wrapper data type-csv \"">\n\n <div class=\""blob-interaction-bar\"">\n <svg aria-hidden=\""true\"" height=\""16\"" viewBox=\""0 0 16 16\"" version=\""1.1\"" width=\""16\"" data-view-component=\""true\"" class=\""octicon octicon-search\"">\n <path d=\""M10.68 11.74a6 6 0 0 1-7.922-8.982 6 6 0 0 1 8.982 7.922l3.04 3.04a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215ZM11.5 7a4.499 4.499 0 1 0-8.997 0A4.499 4.499 0 0 0 11.5 7Z\""><\/path>\n<\/svg>\n <input type=\""text\"" name=\""filter\"" class=\""form-control js-csv-filter-field blob-filter\"" autocapitalize=\""off\""\n placeholder=\""Search this file…\"" aria-label=\""Search this file…\"">\n<\/div>\n\n <div class=\""markdown-body js-check-bidi\"" data-line-alert=\""before\"" data-hpc>\n <template class=\""js-file-alert-template\"">\n <div data-view-component=\""true\"" class=\""flash flash-warn flash-full d-flex flex-items-center\"">\n <svg aria-hidden=\""true\"" height=\""16\"" viewBox=\""0 0 16 16\"" version=\""1.1\"" width=\""16\"" data-view-component=\""true\"" class=\""octicon octicon-alert\"">\n <path d=\""M6.457 1.047c.659-1.234 2.427-1.234 3.086 0l6.082 11.378A1.75 1.75 0 0 1 14.082 15H1.918a1.75 1.75 0 0 1-1.543-2.575Zm1.763.707a.25.25 0 0 0-.44 0L1.698 13.132a.25.25 0 0 0 .22.368h12.164a.25.25 0 0 0 .22-.368Zm.53 3.996v2.5a.75.75 0 0 1-1.5 0v-2.5a.75.75 0 0 1 1.5 0ZM9 11a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z\""><\/path>\n<\/svg>\n <span>\n This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review",open the file in an editor that reveals hidde...
