In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction import stop_words
from sklearn.model_selection import train_test_split
from string import punctuation
from collections import Counter
import matplotlib.pyplot as plt
import pickle
import numpy as np
from PIL import Image
# Stop-word list: sklearn's built-in English stop words plus every ASCII
# punctuation character. Note this rebinds `stop_words` from the imported
# module to a plain list.
stop_words = [*stop_words.ENGLISH_STOP_WORDS, *punctuation]

import torchvision.models as models
from torchvision import transforms
import torch

from src.data import Newsgroups

In [2]:
# Tunable settings for this notebook.
# NOTE(review): `n_features` is not read by any visible cell — presumably the
# vocabulary cap for a TF-IDF vectorizer; confirm or remove.
hparams = dict(n_features=10000)

### CIFAR-100

In [4]:
# `Cifar100` was never brought into scope: the import cell only pulls
# `Newsgroups` from src.data, so this cell failed with NameError on a fresh
# kernel. Import it explicitly here.
# NOTE(review): assumes src.data defines Cifar100 alongside Newsgroups — TODO confirm.
from src.data import Cifar100

# Per the captured output, construction downloads ResNet-152 weights and
# extracts train/test ResNet features, so this cell is slow on first run.
cifar100 = Cifar100(batch_size=64)

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /home/ec2-user/.cache/torch/checkpoints/resnet152-b121ed2d.pth


HBox(children=(FloatProgress(value=0.0, max=241530880.0), HTML(value='')))


Extracting train resnet features...
Extracting test resnet features...
10000

In [5]:
import os

# Collect the CIFAR-100 label metadata and the per-split arrays exposed by
# `cifar100` into a single dict, then pickle it to disk.
data = {
    'class_weights': cifar100.class_weights,
    'flat_label_dict': cifar100.flat_label_dict,
    'hier_label_dict': cifar100.hier_label_dict,
    'train': {
        'images': cifar100.train_images,
        'features': cifar100.train_features,
        'flat_labels': cifar100.train_flat_labels,
        'hier_labels': cifar100.train_hier_labels,
    },
    'test': {
        'images': cifar100.test_images,
        'features': cifar100.test_features,
        'flat_labels': cifar100.test_flat_labels,
        'hier_labels': cifar100.test_hier_labels,
    },
}

# open() raises FileNotFoundError if the target directory does not exist yet.
os.makedirs('data', exist_ok=True)

# Context manager closes the handle (flushing buffered bytes) as soon as the
# dump finishes, instead of whenever an anonymous file object happens to be
# garbage-collected.
# NOTE(review): only unpickle this file from trusted sources — pickle.load
# can execute arbitrary code.
with open('data/cifar100-resnet152.pickle', 'wb') as fh:
    pickle.dump(data, fh)

### Logistic Regression

In [6]:
from sklearn.linear_model import LogisticRegression

##### 20Newsgroups

In [19]:
# Flat logistic-regression baseline for 20 Newsgroups. 'balanced' weights
# each class by n_samples / (n_classes * class_count) to offset class skew.
clf = LogisticRegression(
    class_weight='balanced',
)

In [20]:
# NOTE(review): `train_features` / `train_targets` are not defined in any
# visible cell (presumably built by a since-deleted TF-IDF + train/val split
# cell). This fails with NameError on a fresh kernel — TODO restore that cell.
clf.fit(X=train_features, y=train_targets)



LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='warn', n_jobs=None, penalty='l2', random_state=None,
          solver='warn', tol=0.0001, verbose=0, warm_start=False)

In [21]:
# Mean accuracy on the training data (depends on the same undefined
# `train_features` / `train_targets` as the fit cell — TODO restore).
clf.score(X=train_features, y=train_targets)

0.9643133355430339

In [22]:
# Mean accuracy on the validation split.
# NOTE(review): `val_features` / `val_targets` are not defined in any visible
# cell — TODO restore the cell that creates the validation split.
clf.score(X=val_features, y=val_targets)

0.8824569155987627

In [23]:
# Mean accuracy on the held-out test set.
# NOTE(review): neither `test_features` nor `test` is defined in any visible
# cell; `test.target` presumably comes from fetch_20newsgroups(subset='test').
# The naming is inconsistent with the `*_targets` used above — TODO unify.
clf.score(X=test_features, y=test.target)

0.8068242166755177

##### CIFAR-100

In [7]:
# Flat logistic-regression baseline on the CIFAR-100 ResNet features,
# with class weights inversely proportional to class frequency.
clf = LogisticRegression(
    class_weight='balanced',
)

In [8]:
# Train on the extracted training features against the flat (fine) labels.
clf.fit(X=cifar100.train_features, y=cifar100.train_flat_labels)



LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='warn', n_jobs=None, penalty='l2', random_state=None,
          solver='warn', tol=0.0001, verbose=0, warm_start=False)

In [9]:
# Mean accuracy on the training features (fit quality, not generalization).
clf.score(X=cifar100.train_features, y=cifar100.train_flat_labels)

0.9878

In [10]:
# Mean accuracy on the held-out test features.
clf.score(X=cifar100.test_features, y=cifar100.test_flat_labels)

0.6097

### Supervised Graph Classifier

In [27]:
from sklearn_hierarchical_classification.constants import ROOT
from sklearn.linear_model import LogisticRegression
from sklearn_hierarchical_classification.classifier import HierarchicalClassifier
import networkx as nx

##### 20Newsgroups Tree

In [29]:
# Load 20 Newsgroups and its class tree. The tree's 'ROOT' node is renamed
# to the library's ROOT constant, which HierarchicalClassifier uses as its
# default root (see root='<ROOT>' in the fitted repr).
data = Newsgroups()
graph = nx.relabel_nodes(data.tree, {'ROOT': ROOT})

# Target for each sample = last entry of its tree label
# (presumably the leaf node of a root-to-leaf path — confirm).
train_labels = [
    data.tree_label_dict[flat][-1]
    for flat in data.train_flat_labels
]
test_labels = [
    data.tree_label_dict[flat][-1]
    for flat in data.test_flat_labels
]

In [55]:
# Hierarchical classifier over the 20 Newsgroups tree, using a class-balanced
# logistic regression as the per-node base estimator.
clf = HierarchicalClassifier(
    class_hierarchy=graph,
    base_estimator=LogisticRegression(class_weight='balanced'),
)

In [56]:
# Fit the hierarchy on the training features against their tree leaf labels.
clf.fit(
    data.train_features,
    train_labels,
)



HierarchicalClassifier(algorithm='lcpn',
            base_estimator=LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='warn', n_jobs=None, penalty='l2', random_state=None,
          solver='warn', tol=0.0001, verbose=0, warm_start=False),
            class_hierarchy=<networkx.classes.digraph.DiGraph object at 0x7ff02af25b38>,
            prediction_depth='mlnp', progress_wrapper=None, root='<ROOT>',
            stopping_criteria=None, training_strategy=None)

In [93]:
# Per-sample hierarchical predictions via the library's PRIVATE
# `_recursive_predict`; element [0] of its result is kept (presumably the
# predicted path — confirm against the installed library version).
# `row[None, :]` makes each row a single-sample 2-D input; assumes
# test_features is a dense array (sparse rows may not support None-indexing).
# NOTE(review): `hier_preds` is not used by any visible cell.
hier_preds = [
    clf._recursive_predict(row[None, :], clf.root)[0]
    for row in data.test_features
]

In [34]:
# Mean accuracy of the tree classifier's leaf predictions on the test split.
# The tree label cell binds the test leaf labels to `test_labels`; the
# previously referenced `test_tree_labels` is defined in no visible cell and
# raised NameError on a fresh kernel.
clf.score(data.test_features, test_labels)

0.8008497079129049

##### 20Newsgroups DAG

In [101]:
# Reload 20 Newsgroups, this time using its DAG hierarchy. The 'ROOT' node
# is renamed to the library's ROOT constant, as in the tree section.
data = Newsgroups()
graph = nx.relabel_nodes(data.dag, {'ROOT': ROOT})

# Target for each sample = last entry of its DAG label
# (presumably the deepest node of its path — confirm).
train_labels = [
    data.dag_label_dict[flat][-1]
    for flat in data.train_flat_labels
]
test_labels = [
    data.dag_label_dict[flat][-1]
    for flat in data.test_flat_labels
]

In [102]:
# Hierarchical classifier over the 20 Newsgroups DAG, with a class-balanced
# logistic regression at each node.
clf = HierarchicalClassifier(
    class_hierarchy=graph,
    base_estimator=LogisticRegression(class_weight='balanced'),
)

In [103]:
# Fit the DAG classifier on the training features against their DAG labels.
clf.fit(
    data.train_features,
    train_labels,
)



HierarchicalClassifier(algorithm='lcpn',
            base_estimator=LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='warn', n_jobs=None, penalty='l2', random_state=None,
          solver='warn', tol=0.0001, verbose=0, warm_start=False),
            class_hierarchy=<networkx.classes.digraph.DiGraph object at 0x7ff0b0201470>,
            prediction_depth='mlnp', progress_wrapper=None, root='<ROOT>',
            stopping_criteria=None, training_strategy=None)

In [104]:
# Per-sample hierarchical predictions for the DAG model via the PRIVATE
# `_recursive_predict`; element [0] of its result is kept (presumably the
# predicted path — confirm). Assumes test_features is a dense array so that
# `row[None, :]` yields a 1-row 2-D input.
# NOTE(review): `hier_preds` is not used by any visible cell.
hier_preds = [
    clf._recursive_predict(row[None, :], clf.root)[0]
    for row in data.test_features
]

In [105]:
# Mean accuracy of the DAG classifier on the test split.
clf.score(
    data.test_features,
    test_labels,
)

0.6732607541157727