In [None]:
import io

import datajoint as dj
import requests

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
# Every table below is bound to this DataJoint schema through the @schema decorator.
schema = dj.schema(schema_name='states')

In [None]:
# Two-letter state code -> full state name; these pairs seed the State lookup table.
states = {
    'AL': 'Alabama',        'AK': 'Alaska',       'AZ': 'Arizona',      'AR': 'Arkansas',
    'CA': 'California',     'CO': 'Colorado',     'CT': 'Connecticut',  'DE': 'Delaware',
    'FL': 'Florida',        'GA': 'Georgia',      'HI': 'Hawaii',       'ID': 'Idaho',
    'IL': 'Illinois',       'IN': 'Indiana',      'IA': 'Iowa',         'KS': 'Kansas',
    'KY': 'Kentucky',       'LA': 'Louisiana',    'ME': 'Maine',        'MD': 'Maryland',
    'MA': 'Massachusetts',  'MI': 'Michigan',     'MN': 'Minnesota',    'MS': 'Mississippi',
    'MO': 'Missouri',       'MT': 'Montana',      'NE': 'Nebraska',     'NV': 'Nevada',
    'NH': 'New Hampshire',  'NJ': 'New Jersey',   'NM': 'New Mexico',   'NY': 'New York',
    'NC': 'North Carolina', 'ND': 'North Dakota', 'OH': 'Ohio',         'OK': 'Oklahoma',
    'OR': 'Oregon',         'PA': 'Pennsylvania', 'RI': 'Rhode Island', 'SC': 'South Carolina',
    'SD': 'South Dakota',   'TN': 'Tennessee',    'TX': 'Texas',        'UT': 'Utah',
    'VT': 'Vermont',        'VA': 'Virginia',     'WA': 'Washington',   'WV': 'West Virginia',
    'WI': 'Wisconsin',      'WY': 'Wyoming',
}

In [None]:
@schema
class State(dj.Lookup):
    # Lookup table pre-filled with the (state_code, state) pairs of the `states` dict.
    definition = """
    # United States
    state_code : char(2)
    ---
    state : varchar(20)
    """
    contents = [(code, name) for code, name in states.items()]

In [None]:
# Preview the contents of the State lookup table.
State()

In [None]:
# Render the dependency diagram of the tables declared in `schema` so far.
dj.Diagram(schema)

In [None]:
@schema
class StateBird(dj.Imported):
    """Bird image for each State, downloaded from the URL built by `url_template`."""
    definition = """
    -> State
    ---
    bird_image : longblob    
    """
    
    # {state} is filled with the lowercase state name with spaces removed
    url_template = "http://www.theus50.com/images/state-birds/{state}-bird.jpg"
    
    def make(self, key):
        """Download and decode the bird image for the state in ``key``, then insert it.

        Raises requests.HTTPError if the server answers with an error status.
        """
        # fetch data upstream
        state = (State & key).fetch1('state')
        
        # compute: download and decode in memory, so no temp file is shared
        # between tables or left behind on disk
        url = self.url_template.format(state=state.lower().replace(' ', ''))
        response = requests.get(url, timeout=30)
        # fail on 404 etc. instead of handing an HTML error page to imread
        response.raise_for_status()
        # format='jpg' makes matplotlib decode through Pillow (as for a .jpg file)
        image = plt.imread(io.BytesIO(response.content), format='jpg')
        
        # insert into self
        self.insert1(dict(key, bird_image=image))

In [None]:
StateBird().populate(display_progress=True, suppress_errors=True)

In [None]:
StateBird()

In [None]:
img = (StateBird & {'state_code': "TX"}).fetch1('bird_image')

In [None]:
plt.imshow(img)

In [None]:
@schema
class StateFlag(dj.Imported):
    """Flag image for each State, downloaded from the URL built by `url_template`."""
    definition = """
    -> State
    ---
    flag_image : longblob    
    """
    
    # {state} is filled with the lowercase state name with spaces removed
    url_template = "http://www.theus50.com/images/state-flags/{state}-flag.jpg"
    
    def make(self, key):
        """Download and decode the flag image for the state in ``key``, then insert it.

        Raises requests.HTTPError if the server answers with an error status.
        """
        state = (State & key).fetch1('state')
        url = self.url_template.format(state=state.lower().replace(' ', ''))
        response = requests.get(url, timeout=30)
        # fail on 404 etc. instead of handing an HTML error page to imread
        response.raise_for_status()
        # decode in memory (no shared temp file); format='jpg' routes decoding through Pillow
        image = plt.imread(io.BytesIO(response.content), format='jpg')
        self.insert1(dict(key, flag_image=image))

In [None]:
# Download every state's flag, continuing past states whose make() raises.
StateFlag().populate(suppress_errors=True, display_progress=True)

In [None]:
@schema
class StateFlower(dj.Imported):
    """Flower image for each State, downloaded from the URL built by `url_template`."""
    definition = """
    -> State
    ---
    flower_image : longblob    
    """
    
    # {state} is filled with the lowercase state name with spaces removed
    url_template = "http://www.theus50.com/images/state-flowers/{state}-flower.jpg"
    
    def make(self, key):
        """Download and decode the flower image for the state in ``key``, then insert it.

        Raises requests.HTTPError if the server answers with an error status.
        """
        state = (State & key).fetch1('state')
        url = self.url_template.format(state=state.lower().replace(' ', ''))
        response = requests.get(url, timeout=30)
        # fail on 404 etc. instead of handing an HTML error page to imread
        response.raise_for_status()
        # decode in memory (no shared temp file); format='jpg' routes decoding through Pillow
        image = plt.imread(io.BytesIO(response.content), format='jpg')
        self.insert1(dict(key, flower_image=image))

In [None]:
StateFlower.populate(display_progress=True)

In [None]:
dj.Diagram(schema)

In [None]:
# Show Washington's flag with axis decorations hidden.
wa_flag = (StateFlag & {'state_code': 'WA'}).fetch1('flag_image')
plt.imshow(wa_flag)
plt.axis('off');

In [None]:
# Flags joined with their full state names.
StateFlag * State

In [None]:
# 5 x 10 grid of every downloaded flag, titled with the state name.
fig, axx = plt.subplots(5, 10, figsize=(16, 7))

# Hide axes on every panel up front: if fewer than 50 flags are available (e.g. some
# downloads failed), the leftover panels stay blank instead of showing empty frames.
for ax in axx.flat:
    ax.axis(False)

for ax, info in zip(axx.flat, (StateFlag*State).fetch(as_dict=True)):
    ax.imshow(info['flag_image'])
    ax.set_title(info['state'])

In [None]:
# 5 x 10 grid of every downloaded state bird, titled with the state name.
fig, axx = plt.subplots(5, 10, figsize=(16, 7))

# Hide axes on every panel up front so unused panels (fewer than 50 birds) stay blank.
for ax in axx.flat:
    ax.axis(False)

for ax, info in zip(axx.flat, (StateBird*State).fetch(as_dict=True)):
    ax.imshow(info['bird_image'])
    ax.set_title(info['state'])

In [None]:
# 5 x 10 grid of every downloaded state flower, titled with the state name.
fig, axx = plt.subplots(5, 10, figsize=(16, 7))

# Hide axes on every panel up front so unused panels (fewer than 50 flowers) stay blank.
for ax in axx.flat:
    ax.axis(False)

for ax, info in zip(axx.flat, (StateFlower*State).fetch(as_dict=True)):
    ax.imshow(info['flower_image'])
    ax.set_title(info['state'])

In [None]:
@schema
class FlagSaturation(dj.Computed):
    """How strongly one average color channel dominates each state flag."""
    definition = """
    -> StateFlag
    ---
    saturation :  float 
    """
    
    def make(self, key):
        """Compute and insert the saturation score of the flag in ``key``.

        saturation = max(mean channel) / sum(mean channels) over the color channels.
        It is 1/3 when all channels are equal (a neutral image) and 1 when only one
        channel is nonzero.
        """
        img = (StateFlag & key).fetch1('flag_image')
        if img.ndim == 2:
            # grayscale image: every "channel" is equal, i.e. neutral
            sat = 1 / 3
        else:
            # average color over all pixels; [..., :3] ignores an alpha channel if present
            avg_color = img[..., :3].mean(axis=(0, 1))
            total = avg_color.sum()
            # an all-black image has no dominant channel; treat it as neutral (avoids 0/0)
            sat = avg_color.max() / total if total > 0 else 1 / 3
        self.insert1(dict(key, saturation=float(sat)))
        

In [None]:
FlagSaturation.populate()

In [None]:
FlagSaturation * State * StateFlag

In [None]:
# Flags laid out from least to most saturated.
fig, axx = plt.subplots(5, 10, figsize=(16, 7))

# Hide axes on every panel up front so unused panels (fewer than 50 flags) stay blank.
for ax in axx.flat:
    ax.axis(False)

for ax, info in zip(axx.flat,
                    (StateFlag*State*FlagSaturation).fetch(as_dict=True, order_by='saturation')):
    ax.imshow(info['flag_image'])
    ax.set_title(info['state'])

In [None]:
dj.Diagram(schema)

In [None]:
StateFlower.populate()
StateFlag.populate()
StateBird.populate()
FlagSaturation.populate()

In [None]:
@schema
class FlagBrightness(dj.Computed):
    """Mean pixel value of each state flag image."""
    definition = """
    -> StateFlag
    ---
    brightness :  float 
    """
    
    def make(self, key):
        """Insert the mean of all values in the flag image for ``key`` as its brightness."""
        flag = (StateFlag & key).fetch1('flag_image')
        mean_value = flag.mean()
        self.insert1({**key, 'brightness': mean_value})

In [None]:
# Compute the brightness of every downloaded flag.
FlagBrightness().populate()

In [None]:
# Flags laid out from darkest to brightest.
fig, axx = plt.subplots(5, 10, figsize=(16, 7))

# Hide axes on every panel up front so unused panels (fewer than 50 flags) stay blank.
for ax in axx.flat:
    ax.axis(False)

for ax, info in zip(axx.flat,
                    (StateFlag*State*FlagBrightness).fetch(as_dict=True, order_by='brightness')):
    ax.imshow(info['flag_image'])
    ax.set_title(info['state'])

In [None]:
@schema
class FlagContrast(dj.Computed):
    """Relative contrast of each state flag."""
    definition = """
    -> FlagBrightness
    ---
    contrast :  float 
    """
    
    def make(self, key):
        """Insert std(per-pixel intensity) / brightness for the flag in ``key``."""
        img, brightness = (StateFlag * FlagBrightness & key).fetch1('flag_image', 'brightness')
        # Per-pixel intensity: average the channels of a 3-D (rows, cols, channels) image.
        # A 2-D grayscale image already holds one value per pixel; averaging its last axis
        # would collapse whole rows instead.
        intensity = img.mean(axis=-1) if img.ndim == 3 else img
        # an all-black flag (brightness 0) has no variation; report 0 instead of 0/0
        contrast = intensity.std() / brightness if brightness > 0 else 0.0
        self.insert1(dict(key, contrast=float(contrast)))

In [None]:
# Compute the contrast of every flag that has a brightness value.
FlagContrast().populate(display_progress=True)

In [None]:
# Flags laid out from lowest to highest contrast.
fig, axx = plt.subplots(5, 10, figsize=(16, 7))

# Hide axes on every panel up front so unused panels (fewer than 50 flags) stay blank.
for ax in axx.flat:
    ax.axis(False)

for ax, info in zip(axx.flat,
                    (StateFlag*State*FlagContrast).fetch(as_dict=True, order_by='contrast')):
    ax.imshow(info['flag_image'])
    ax.set_title(info['state'])

In [None]:
# Redraw the schema diagram now that every table above has been declared.
dj.Diagram(schema)